Diffstat (limited to 'pkg')
-rw-r--r--  pkg/atomicbitops/atomicbitops_amd64.s | 16
-rw-r--r--  pkg/atomicbitops/atomicbitops_arm64.s | 16
-rw-r--r--  pkg/atomicbitops/atomicbitops_noasm.go | 8
-rw-r--r--  pkg/buffer/view_test.go | 18
-rw-r--r--  pkg/eventfd/BUILD | 22
-rw-r--r--  pkg/eventfd/eventfd.go | 115
-rw-r--r--  pkg/eventfd/eventfd_test.go | 75
-rw-r--r--  pkg/ring0/defs.go | 12
-rw-r--r--  pkg/ring0/defs_amd64.go | 5
-rw-r--r--  pkg/ring0/entry_amd64.go | 5
-rw-r--r--  pkg/ring0/entry_amd64.s | 177
-rw-r--r--  pkg/ring0/kernel.go | 5
-rw-r--r--  pkg/ring0/kernel_amd64.go | 23
-rw-r--r--  pkg/ring0/lib_amd64.go | 42
-rw-r--r--  pkg/ring0/lib_amd64.s | 23
-rw-r--r--  pkg/ring0/offsets_amd64.go | 3
-rw-r--r--  pkg/safecopy/BUILD | 2
-rw-r--r--  pkg/safecopy/safecopy.go | 5
-rw-r--r--  pkg/safecopy/safecopy_unsafe.go | 37
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe.go | 3
-rw-r--r--  pkg/sentry/fs/host/inode.go | 2
-rw-r--r--  pkg/sentry/fs/inotify.go | 2
-rw-r--r--  pkg/sentry/fs/lock/lock.go | 2
-rw-r--r--  pkg/sentry/fs/proc/sys.go | 3
-rw-r--r--  pkg/sentry/fs/timerfd/timerfd.go | 2
-rw-r--r--  pkg/sentry/fs/tty/line_discipline.go | 4
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 16
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_sys.go | 9
-rw-r--r--  pkg/sentry/hostmm/BUILD | 3
-rw-r--r--  pkg/sentry/hostmm/hostmm.go | 42
-rw-r--r--  pkg/sentry/kernel/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/epoll/epoll.go | 15
-rw-r--r--  pkg/sentry/kernel/epoll/epoll_state.go | 2
-rw-r--r--  pkg/sentry/kernel/eventfd/eventfd.go | 2
-rw-r--r--  pkg/sentry/kernel/kernel.go | 46
-rw-r--r--  pkg/sentry/kernel/pipe/pipe.go | 2
-rw-r--r--  pkg/sentry/kernel/task.go | 2
-rw-r--r--  pkg/sentry/kernel/task_log.go | 6
-rw-r--r--  pkg/sentry/kernel/threads.go | 6
-rw-r--r--  pkg/sentry/platform/kvm/BUILD | 2
-rw-r--r--  pkg/sentry/platform/kvm/bluepill.go | 4
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64.go | 8
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.go | 10
-rw-r--r--  pkg/sentry/platform/kvm/kvm.go | 6
-rw-r--r--  pkg/sentry/platform/kvm/machine.go | 4
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64.go | 11
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64.go | 5
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 3
-rw-r--r--  pkg/sentry/socket/BUILD | 5
-rw-r--r--  pkg/sentry/socket/control/control.go | 13
-rw-r--r--  pkg/sentry/socket/control/control_test.go | 2
-rw-r--r--  pkg/sentry/socket/hostinet/socket.go | 28
-rw-r--r--  pkg/sentry/socket/netfilter/targets.go | 2
-rw-r--r--  pkg/sentry/socket/netstack/BUILD | 1
-rw-r--r--  pkg/sentry/socket/netstack/netstack.go | 112
-rw-r--r--  pkg/sentry/socket/netstack/netstack_state.go | 31
-rw-r--r--  pkg/sentry/socket/socket.go | 13
-rw-r--r--  pkg/sentry/socket/socket_state.go | 27
-rw-r--r--  pkg/sentry/strace/strace.go | 4
-rw-r--r--  pkg/sentry/vfs/epoll.go | 6
-rw-r--r--  pkg/shim/service.go | 20
-rw-r--r--  pkg/shim/service_test.go | 20
-rw-r--r--  pkg/sighandling/BUILD (renamed from pkg/sentry/sighandling/BUILD) | 2
-rw-r--r--  pkg/sighandling/sighandling.go (renamed from pkg/sentry/sighandling/sighandling.go) | 0
-rw-r--r--  pkg/sighandling/sighandling_unsafe.go (renamed from pkg/sentry/sighandling/sighandling_unsafe.go) | 34
-rw-r--r--  pkg/sync/BUILD | 1
-rw-r--r--  pkg/sync/wait.go | 58
-rw-r--r--  pkg/tcpip/BUILD | 1
-rw-r--r--  pkg/tcpip/link/rawfile/rawfile_unsafe.go | 16
-rw-r--r--  pkg/tcpip/link/sharedmem/BUILD | 30
-rw-r--r--  pkg/tcpip/link/sharedmem/queue/rx.go | 2
-rw-r--r--  pkg/tcpip/link/sharedmem/queuepair.go | 199
-rw-r--r--  pkg/tcpip/link/sharedmem/rx.go | 30
-rw-r--r--  pkg/tcpip/link/sharedmem/server_rx.go | 142
-rw-r--r--  pkg/tcpip/link/sharedmem/server_tx.go | 175
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem.go | 230
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_server.go | 333
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_server_test.go | 220
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_test.go | 103
-rw-r--r--  pkg/tcpip/link/sharedmem/tx.go | 20
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 16
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go | 4
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 14
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go | 17
-rw-r--r--  pkg/tcpip/stack/conntrack.go | 589
-rw-r--r--  pkg/tcpip/stack/iptables.go | 132
-rw-r--r--  pkg/tcpip/stack/iptables_state.go | 4
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go | 192
-rw-r--r--  pkg/tcpip/stack/iptables_types.go | 28
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go | 30
-rw-r--r--  pkg/tcpip/stack/packet_buffer_test.go | 38
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 11
-rw-r--r--  pkg/tcpip/stack/tcp.go | 6
-rw-r--r--  pkg/tcpip/tcpip.go | 22
-rw-r--r--  pkg/tcpip/tcpip_state.go | 27
-rw-r--r--  pkg/tcpip/tests/integration/iptables_test.go | 469
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go | 2
-rw-r--r--  pkg/tcpip/transport/packet/endpoint.go | 4
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go | 42
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go | 308
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 101
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go | 6
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go | 12
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go | 127
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_rack_test.go | 15
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_sack_test.go | 255
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go | 2
-rw-r--r--  pkg/unet/BUILD | 1
-rw-r--r--  pkg/unet/unet.go | 20
-rw-r--r--  pkg/unet/unet_unsafe.go | 2
111 files changed, 3744 insertions, 1466 deletions
diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index cbaf716bb..6b9a67adc 100644
--- a/pkg/atomicbitops/atomicbitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -16,28 +16,28 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),$0-12
+TEXT ·AndUint32(SB),NOSPLIT,$0-12
MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
ANDL AX, 0(BX)
RET
-TEXT ·OrUint32(SB),$0-12
+TEXT ·OrUint32(SB),NOSPLIT,$0-12
MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
ORL AX, 0(BX)
RET
-TEXT ·XorUint32(SB),$0-12
+TEXT ·XorUint32(SB),NOSPLIT,$0-12
MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
XORL AX, 0(BX)
RET
-TEXT ·CompareAndSwapUint32(SB),$0-20
+TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVQ addr+0(FP), DI
MOVL old+8(FP), AX
MOVL new+12(FP), DX
@@ -46,28 +46,28 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVL AX, ret+16(FP)
RET
-TEXT ·AndUint64(SB),$0-16
+TEXT ·AndUint64(SB),NOSPLIT,$0-16
MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
ANDQ AX, 0(BX)
RET
-TEXT ·OrUint64(SB),$0-16
+TEXT ·OrUint64(SB),NOSPLIT,$0-16
MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
ORQ AX, 0(BX)
RET
-TEXT ·XorUint64(SB),$0-16
+TEXT ·XorUint64(SB),NOSPLIT,$0-16
MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
XORQ AX, 0(BX)
RET
-TEXT ·CompareAndSwapUint64(SB),$0-32
+TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
MOVQ addr+0(FP), DI
MOVQ old+8(FP), AX
MOVQ new+16(FP), DX
diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 5c780851b..644a6bca5 100644
--- a/pkg/atomicbitops/atomicbitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -16,7 +16,7 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),$0-12
+TEXT ·AndUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -26,7 +26,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint32(SB),$0-12
+TEXT ·OrUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -36,7 +36,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint32(SB),$0-12
+TEXT ·XorUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -46,7 +46,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint32(SB),$0-20
+TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
@@ -60,7 +60,7 @@ done:
MOVW R3, prev+16(FP)
RET
-TEXT ·AndUint64(SB),$0-16
+TEXT ·AndUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -70,7 +70,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint64(SB),$0-16
+TEXT ·OrUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -80,7 +80,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint64(SB),$0-16
+TEXT ·XorUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -90,7 +90,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint64(SB),$0-32
+TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 474c0c815..af6b1362e 100644
--- a/pkg/atomicbitops/atomicbitops_noasm.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -21,6 +21,7 @@ import (
"sync/atomic"
)
+//go:nosplit
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,6 +32,7 @@ func AndUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -41,6 +43,7 @@ func OrUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -51,6 +54,7 @@ func XorUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -63,6 +67,7 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
+//go:nosplit
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -73,6 +78,7 @@ func AndUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -83,6 +89,7 @@ func OrUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -93,6 +100,7 @@ func XorUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
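Note: every noasm fallback above follows the same shape — snapshot the value, compute the result, and publish it with a compare-and-swap, retrying on contention. A minimal sketch of that loop (not part of the diff; assumes only sync/atomic):

    //go:nosplit
    func andUint32(addr *uint32, val uint32) {
        for {
            o := atomic.LoadUint32(addr) // Snapshot the current value.
            n := o & val                 // Compute the desired new value.
            if atomic.CompareAndSwapUint32(addr, o, n) {
                return // No concurrent writer raced us; the update stuck.
            }
            // CAS failed: *addr changed underneath us; retry with a fresh snapshot.
        }
    }

The //go:nosplit and NOSPLIT annotations added in this change keep these functions callable from code that itself must not grow the stack.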
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
index 796efa240..59784eacb 100644
--- a/pkg/buffer/view_test.go
+++ b/pkg/buffer/view_test.go
@@ -509,6 +509,24 @@ func TestView(t *testing.T) {
}
}
+func TestViewClone(t *testing.T) {
+ const (
+ originalSize = 90
+ bytesToDelete = 30
+ )
+ var v View
+ v.AppendOwned(bytes.Repeat([]byte{originalSize}, originalSize))
+
+ clonedV := v.Clone()
+ v.TrimFront(bytesToDelete)
+ if got, want := int(v.Size()), originalSize-bytesToDelete; got != want {
+ t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
+ }
+ if got := clonedV.Size(); got != originalSize {
+ t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
+ }
+}
+
func TestViewPullUp(t *testing.T) {
for _, tc := range []struct {
desc string
diff --git a/pkg/eventfd/BUILD b/pkg/eventfd/BUILD
new file mode 100644
index 000000000..02407cb99
--- /dev/null
+++ b/pkg/eventfd/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "eventfd",
+ srcs = [
+ "eventfd.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/hostarch",
+ "//pkg/tcpip/link/rawfile",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "eventfd_test",
+ srcs = ["eventfd_test.go"],
+ library = ":eventfd",
+)
diff --git a/pkg/eventfd/eventfd.go b/pkg/eventfd/eventfd.go
new file mode 100644
index 000000000..acdac01b8
--- /dev/null
+++ b/pkg/eventfd/eventfd.go
@@ -0,0 +1,115 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd wraps Linux's eventfd(2) syscall.
+package eventfd
+
+import (
+ "fmt"
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const sizeofUint64 = 8
+
+// Eventfd represents a Linux eventfd object.
+type Eventfd struct {
+ fd int
+}
+
+// Create returns an initialized eventfd.
+func Create() (Eventfd, error) {
+ fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
+ if err != 0 {
+ return Eventfd{}, fmt.Errorf("failed to create eventfd: %v", error(err))
+ }
+ if err := unix.SetNonblock(int(fd), true); err != nil {
+ unix.Close(int(fd))
+ return Eventfd{}, err
+ }
+ return Eventfd{int(fd)}, nil
+}
+
+// Wrap returns an initialized Eventfd using the provided fd.
+func Wrap(fd int) Eventfd {
+ return Eventfd{fd}
+}
+
+// Close closes the eventfd, after which it should not be used.
+func (ev Eventfd) Close() error {
+ return unix.Close(ev.fd)
+}
+
+// Dup copies the eventfd, calling dup(2) on the underlying file descriptor.
+func (ev Eventfd) Dup() (Eventfd, error) {
+ other, err := unix.Dup(ev.fd)
+ if err != nil {
+ return Eventfd{}, fmt.Errorf("failed to dup: %v", err)
+ }
+ return Eventfd{other}, nil
+}
+
+// Notify alerts other users of the eventfd. Users can receive alerts by
+// calling Wait or Read.
+func (ev Eventfd) Notify() error {
+ return ev.Write(1)
+}
+
+// Write writes a specific value to the eventfd.
+func (ev Eventfd) Write(val uint64) error {
+ var buf [sizeofUint64]byte
+ hostarch.ByteOrder.PutUint64(buf[:], val)
+ for {
+ n, err := unix.Write(ev.fd, buf[:])
+ if err == unix.EINTR {
+ continue
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short write to eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ return err
+ }
+}
+
+// Wait blocks until eventfd is non-zero (i.e. someone calls Notify or Write).
+func (ev Eventfd) Wait() error {
+ _, err := ev.Read()
+ return err
+}
+
+// Read blocks until eventfd is non-zero (i.e. someone calls Notify or Write)
+// and returns the value read.
+func (ev Eventfd) Read() (uint64, error) {
+ var tmp [sizeofUint64]byte
+ n, err := rawfile.BlockingReadUntranslated(ev.fd, tmp[:])
+ if err != 0 {
+ return 0, err
+ }
+ if n == 0 {
+ return 0, io.EOF
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short read from eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ return hostarch.ByteOrder.Uint64(tmp[:]), nil
+}
+
+// FD returns the underlying file descriptor. Use with care, as this breaks the
+// Eventfd abstraction.
+func (ev Eventfd) FD() int {
+ return ev.fd
+}
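Taken together, Create/Notify/Write/Wait/Read form a small blocking-counter primitive around the host eventfd. A hypothetical use (sketch only; the function name and error handling are illustrative, not part of this change):

    // waitForNotify blocks until some other goroutine signals the eventfd.
    func waitForNotify() error {
        efd, err := eventfd.Create()
        if err != nil {
            return err
        }
        defer efd.Close()
        go efd.Notify()   // Increment the counter, waking the waiter below.
        return efd.Wait() // Block until the counter is non-zero; the read resets it to zero.
    }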
diff --git a/pkg/eventfd/eventfd_test.go b/pkg/eventfd/eventfd_test.go
new file mode 100644
index 000000000..96998d530
--- /dev/null
+++ b/pkg/eventfd/eventfd_test.go
@@ -0,0 +1,75 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventfd
+
+import (
+ "testing"
+ "time"
+)
+
+func TestReadWrite(t *testing.T) {
+ efd, err := Create()
+ if err != nil {
+ t.Fatalf("failed to Create(): %v", err)
+ }
+ defer efd.Close()
+
+ // Make sure we can read actual values
+ const want = 343
+ if err := efd.Write(want); err != nil {
+ t.Fatalf("failed to write value: %d", want)
+ }
+
+ got, err := efd.Read()
+ if err != nil {
+ t.Fatalf("failed to read value: %v", err)
+ }
+ if got != want {
+ t.Fatalf("Read(): got %d, but wanted %d", got, want)
+ }
+}
+
+func TestWait(t *testing.T) {
+ efd, err := Create()
+ if err != nil {
+ t.Fatalf("failed to Create(): %v", err)
+ }
+ defer efd.Close()
+
+ // There's no way to test with certainty that Wait() blocks indefinitely, but
+ // as a best-effort we can wait a bit on it.
+ errCh := make(chan error)
+ go func() {
+ errCh <- efd.Wait()
+ }()
+ select {
+ case err := <-errCh:
+ t.Fatalf("Wait() returned without a call to Notify(): %v", err)
+ case <-time.After(500 * time.Millisecond):
+ }
+
+ // Notify and check that Wait() returned.
+ if err := efd.Notify(); err != nil {
+ t.Fatalf("Notify() failed: %v", err)
+ }
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Read() failed: %v", err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Read() did not return after Notify()")
+ }
+}
diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go
index b6e2012e8..38ce9be1e 100644
--- a/pkg/ring0/defs.go
+++ b/pkg/ring0/defs.go
@@ -77,6 +77,9 @@ type CPU struct {
// calls and exceptions via the Registers function.
registers arch.Registers
+ // floatingPointState holds floating point state.
+ floatingPointState fpu.State
+
// hooks are kernel hooks.
hooks Hooks
}
@@ -90,6 +93,15 @@ func (c *CPU) Registers() *arch.Registers {
return &c.registers
}
+// FloatingPointState returns the kernel floating point state.
+//
+// This is explicitly safe to call during KernelException and KernelSyscall.
+//
+//go:nosplit
+func (c *CPU) FloatingPointState() *fpu.State {
+ return &c.floatingPointState
+}
+
// SwitchOpts are passed to the Switch function.
type SwitchOpts struct {
// Registers are the user register state.
diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go
index 24f6e4cde..81e90dbf7 100644
--- a/pkg/ring0/defs_amd64.go
+++ b/pkg/ring0/defs_amd64.go
@@ -116,6 +116,11 @@ type CPUArchState struct {
errorType uintptr
*kernelEntry
+
+ // Copies of global variables, stored in CPU so that they can be used by
+ // syscall and exception handlers (in the upper address space).
+ hasXSAVE bool
+ hasXSAVEOPT bool
}
// ErrorCode returns the last error code.
diff --git a/pkg/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go
index afd646b0b..13ad4e4df 100644
--- a/pkg/ring0/entry_amd64.go
+++ b/pkg/ring0/entry_amd64.go
@@ -39,11 +39,6 @@ func sysenter()
// assembly to get the ABI0 (i.e., primary) address.
func addrOfSysenter() uintptr
-// swapgs swaps the current GS value.
-//
-// This must be called prior to sysret/iret.
-func swapgs()
-
// jumpToKernel jumps to the kernel version of the current RIP.
func jumpToKernel()
diff --git a/pkg/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s
index 520bd9f57..d2913f190 100644
--- a/pkg/ring0/entry_amd64.s
+++ b/pkg/ring0/entry_amd64.s
@@ -142,8 +142,103 @@ TEXT ·jumpToUser(SB),NOSPLIT,$0
MOVQ AX, 0(SP)
RET
+// See kernel_amd64.go.
+//
+// The 16-byte frame size is for the saved values of MXCSR and the x87 control
+// word.
+TEXT ·doSwitchToUser(SB),NOSPLIT,$16-48
+ // We are passed pointers to heap objects, but do not store them in our
+ // local frame.
+ NO_LOCAL_POINTERS
+
+ // MXCSR and the x87 control word are the only floating point state
+ // that is callee-save and thus we must save.
+ STMXCSR mxcsr-0(SP)
+ FSTCW cw-8(SP)
+
+ // Restore application floating point state.
+ MOVQ cpu+0(FP), SI
+ MOVQ fpState+16(FP), DI
+ MOVB ·hasXSAVE(SB), BX
+ TESTB BX, BX
+ JZ no_xrstor
+ // Use xrstor to restore all available fp state. For now, we restore
+ // everything unconditionally by setting the implicit operand edx:eax
+ // (the "requested feature bitmap") to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI)
+ JMP fprestore_done
+no_xrstor:
+ // Fall back to fxrstor if xsave is not available.
+ FXRSTOR64 0(DI)
+fprestore_done:
+
+ // Set application GS.
+ MOVQ regs+8(FP), R8
+ SWAP_GS()
+ MOVQ PTRACE_GS_BASE(R8), AX
+ PUSHQ AX
+ CALL ·writeGS(SB)
+ POPQ AX
+
+ // Call sysret() or iret().
+ MOVQ userCR3+24(FP), CX
+ MOVQ needIRET+32(FP), R9
+ ADDQ $-32, SP
+ MOVQ SI, 0(SP) // cpu
+ MOVQ R8, 8(SP) // regs
+ MOVQ CX, 16(SP) // userCR3
+ TESTQ R9, R9
+ JNZ do_iret
+ CALL ·sysret(SB)
+ JMP done_sysret_or_iret
+do_iret:
+ CALL ·iret(SB)
+done_sysret_or_iret:
+ MOVQ 24(SP), AX // vector
+ ADDQ $32, SP
+ MOVQ AX, vector+40(FP)
+
+ // Save application floating point state.
+ MOVQ fpState+16(FP), DI
+ MOVB ·hasXSAVE(SB), BX
+ MOVB ·hasXSAVEOPT(SB), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
+
+ // Restore MXCSR and the x87 control word after one of the two floating
+ // point save cases above, to ensure the application versions are saved
+ // before being clobbered here.
+ LDMXCSR mxcsr-0(SP)
+
+ // FLDCW is a "waiting" x87 instruction, meaning it checks for pending
+ // unmasked exceptions before executing. Thus if userspace has unmasked
+ // an exception and has one pending, it can be raised by FLDCW even
+ // though the new control word will mask exceptions. To prevent this,
+ // we must first clear pending exceptions (which will be restored by
+ // XRSTOR, et al).
+ BYTE $0xDB; BYTE $0xE2; // FNCLEX
+ FLDCW cw-8(SP)
+
+ RET
+
// See entry_amd64.go.
-TEXT ·sysret(SB),NOSPLIT,$0-24
+TEXT ·sysret(SB),NOSPLIT,$0-32
// Set application FS. We can't do this in Go because Go code needs FS.
MOVQ regs+8(FP), AX
MOVQ PTRACE_FS_BASE(AX), AX
@@ -182,9 +277,11 @@ TEXT ·sysret(SB),NOSPLIT,$0-24
POPQ AX // Restore AX.
POPQ SP // Restore SP.
SYSRET64()
+ // sysenter or exception will write our return value and return to our
+ // caller.
// See entry_amd64.go.
-TEXT ·iret(SB),NOSPLIT,$0-24
+TEXT ·iret(SB),NOSPLIT,$0-32
// Set application FS. We can't do this in Go because Go code needs FS.
MOVQ regs+8(FP), AX
MOVQ PTRACE_FS_BASE(AX), AX
@@ -220,6 +317,8 @@ TEXT ·iret(SB),NOSPLIT,$0-24
WRITE_CR3() // Switch to userCR3.
POPQ AX // Restore AX.
IRET()
+ // sysenter or exception will write our return value and return to our
+ // caller.
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
@@ -324,11 +423,39 @@ kernel:
MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
+ // Save floating point state. CPU.floatingPointState is a slice, so the
+ // first word of CPU.floatingPointState is a pointer to the destination
+ // array.
+ MOVQ CPU_FPU_STATE(AX), DI
+ MOVB CPU_HAS_XSAVE(AX), BX
+ MOVB CPU_HAS_XSAVEOPT(AX), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
+
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- PUSHQ AX // First argument (vCPU).
- CALL ·kernelSyscall(SB) // Call the trampoline.
- POPQ AX // Pop vCPU.
+ MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU.
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ POPQ AX // Pop vCPU.
+
+ // We only trigger a bluepill entry in the bluepill function, and can
+ // therefore be guaranteed that there is no floating point state to be
+ // loaded on resuming from halt.
JMP ·resume(SB)
ADDR_OF_FUNC(·addrOfSysenter(SB), ·sysenter(SB));
@@ -416,15 +543,43 @@ kernel:
MOVQ 8(SP), BX // Load the error code.
MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
- MOVQ 0(SP), BX // BX contains the vector.
+
+ // Save floating point state. CPU.floatingPointState is a slice, so the
+ // first word of CPU.floatingPointState is a pointer to the destination
+ // array.
+ MOVQ CPU_FPU_STATE(AX), DI
+ MOVB CPU_HAS_XSAVE(AX), BX
+ MOVB CPU_HAS_XSAVEOPT(AX), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
// Call the exception trampoline.
+ MOVQ 0(SP), BX // BX contains the vector.
LOAD_KERNEL_STACK(GS)
- PUSHQ BX // Second argument (vector).
- PUSHQ AX // First argument (vCPU).
- CALL ·kernelException(SB) // Call the trampoline.
- POPQ BX // Pop vector.
- POPQ AX // Pop vCPU.
+ MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU.
+ PUSHQ BX // Second argument (vector).
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelException(SB) // Call the trampoline.
+ POPQ BX // Pop vector.
+ POPQ AX // Pop vCPU.
+
+ // We only trigger a bluepill entry in the bluepill function, and can
+ // therefore be guaranteed that there is no floating point state to be
+ // loaded on resuming from halt.
JMP ·resume(SB)
#define EXCEPTION_WITH_ERROR(value, symbol, addr) \
diff --git a/pkg/ring0/kernel.go b/pkg/ring0/kernel.go
index 292f9d0cc..e7dd84929 100644
--- a/pkg/ring0/kernel.go
+++ b/pkg/ring0/kernel.go
@@ -14,6 +14,10 @@
package ring0
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
+)
+
// Init initializes a new kernel.
//
//go:nosplit
@@ -80,6 +84,7 @@ func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
c.self = c // Set self reference.
c.kernel = k // Set kernel reference.
c.init(cpuID) // Perform architectural init.
+ c.floatingPointState = fpu.NewState()
// Require hooks.
if hooks != nil {
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 4a4c0ae26..7e55011b5 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -143,6 +143,9 @@ func (c *CPU) init(cpuID int) {
// Set mandatory flags.
c.registers.Eflags = KernelFlagsSet
+
+ c.hasXSAVE = hasXSAVE
+ c.hasXSAVEOPT = hasXSAVEOPT
}
// StackTop returns the kernel's stack address.
@@ -248,19 +251,21 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Ss = uint64(Udata) // Ditto.
// Perform the switch.
- swapgs() // GS will be swapped on return.
- WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point.
+ needIRET := uint64(0)
if switchOpts.FullRestore {
- vector = iret(c, regs, uintptr(userCR3))
- } else {
- vector = sysret(c, regs, uintptr(userCR3))
+ needIRET = 1
}
- SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
- RestoreKernelFPState() // escapes: no. Restore kernel MXCSR.
+ vector = doSwitchToUser(c, regs, switchOpts.FloatingPointState.BytePointer(), userCR3, needIRET) // escapes: no.
return
}
+func doSwitchToUser(
+ cpu *CPU, // +0(FP)
+ regs *arch.Registers, // +8(FP)
+ fpState *byte, // +16(FP)
+ userCR3 uint64, // +24(FP)
+ needIRET uint64) Vector // +32(FP), +40(FP)
+
var (
sentryXCR0 uintptr
sentryXCR0Once sync.Once
@@ -287,7 +292,7 @@ func initSentryXCR0() {
//go:nosplit
func startGo(c *CPU) {
// Save per-cpu.
- WriteGS(kernelAddr(c.kernelEntry))
+ writeGS(kernelAddr(c.kernelEntry))
//
// TODO(mpratt): Note that per the note above, this should be done
diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go
index 05c394ff5..c42a5b205 100644
--- a/pkg/ring0/lib_amd64.go
+++ b/pkg/ring0/lib_amd64.go
@@ -21,29 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
)
-// LoadFloatingPoint loads floating point state by the most efficient mechanism
-// available (set by Init).
-var LoadFloatingPoint func(*byte)
-
-// SaveFloatingPoint saves floating point state by the most efficient mechanism
-// available (set by Init).
-var SaveFloatingPoint func(*byte)
-
-// fxrstor uses fxrstor64 to load floating point state.
-func fxrstor(*byte)
-
-// xrstor uses xrstor to load floating point state.
-func xrstor(*byte)
-
-// fxsave uses fxsave64 to save floating point state.
-func fxsave(*byte)
-
-// xsave uses xsave to save floating point state.
-func xsave(*byte)
-
-// xsaveopt uses xsaveopt to save floating point state.
-func xsaveopt(*byte)
-
// writeFS sets the FS base address (selects one of wrfsbase or wrfsmsr).
func writeFS(addr uintptr)
@@ -53,8 +30,8 @@ func wrfsbase(addr uintptr)
// wrfsmsr writes to the GS_BASE MSR.
func wrfsmsr(addr uintptr)
-// WriteGS sets the GS address (set by init).
-var WriteGS func(addr uintptr)
+// writeGS sets the GS address (selects one of wrgsbase or wrgsmsr).
+func writeGS(addr uintptr)
// wrgsbase writes to the GS base address.
func wrgsbase(addr uintptr)
@@ -106,19 +83,4 @@ func Init(featureSet *cpuid.FeatureSet) {
hasXSAVE = featureSet.UseXsave()
hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase)
validXCR0Mask = uintptr(featureSet.ValidXCR0Mask())
- if hasXSAVEOPT {
- SaveFloatingPoint = xsaveopt
- LoadFloatingPoint = xrstor
- } else if hasXSAVE {
- SaveFloatingPoint = xsave
- LoadFloatingPoint = xrstor
- } else {
- SaveFloatingPoint = fxsave
- LoadFloatingPoint = fxrstor
- }
- if hasFSGSBASE {
- WriteGS = wrgsbase
- } else {
- WriteGS = wrgsmsr
- }
}
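The save/restore dispatch deleted here does not disappear: it is inlined, in assembly, into entry_amd64.s and doSwitchToUser, keyed off the hasXSAVE/hasXSAVEOPT copies stored in the CPU. In Go terms, the selection the assembly now performs is roughly (sketch, not part of the diff):

    switch {
    case hasXSAVE && hasXSAVEOPT:
        xsaveopt(dst) // XSAVEOPT64: skips state components still in their init state.
    case hasXSAVE:
        xsave(dst) // XSAVE64: full extended state.
    default:
        fxsave(dst) // FXSAVE64: legacy x87/SSE state only.
    }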
diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s
index 8ed98fc84..0f283aaae 100644
--- a/pkg/ring0/lib_amd64.s
+++ b/pkg/ring0/lib_amd64.s
@@ -128,6 +128,29 @@ TEXT ·wrfsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30;
RET
+// writeGS writes to the GS base.
+//
+// This is written in assembly because it must be callable from assembly (ABI0)
+// without an intermediate transition to ABIInternal.
+//
+// Preconditions: must be running in the lower address space, as it accesses
+// global data.
+TEXT ·writeGS(SB),NOSPLIT,$8-8
+ MOVQ addr+0(FP), AX
+
+ CMPB ·hasFSGSBASE(SB), $1
+ JNE msr
+
+ PUSHQ AX
+ CALL ·wrgsbase(SB)
+ POPQ AX
+ RET
+msr:
+ PUSHQ AX
+ CALL ·wrgsmsr(SB)
+ POPQ AX
+ RET
+
// wrgsbase writes to the GS base.
//
// The code corresponds to:
diff --git a/pkg/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go
index 75f6218b3..38fe27c35 100644
--- a/pkg/ring0/offsets_amd64.go
+++ b/pkg/ring0/offsets_amd64.go
@@ -35,6 +35,9 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_HAS_XSAVE 0x%02x\n", reflect.ValueOf(&c.hasXSAVE).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_HAS_XSAVEOPT 0x%02x\n", reflect.ValueOf(&c.hasXSAVEOPT).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_FPU_STATE 0x%02x\n", reflect.ValueOf(&c.floatingPointState).Pointer()-reflect.ValueOf(c).Pointer())
e := &kernelEntry{}
fmt.Fprintf(w, "\n// CPU entry offsets.\n")
diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD
index 0a045fc8e..2a1602e2b 100644
--- a/pkg/safecopy/BUILD
+++ b/pkg/safecopy/BUILD
@@ -18,9 +18,9 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
- "//pkg/abi/linux",
"//pkg/errors",
"//pkg/errors/linuxerr",
+ "//pkg/sighandling",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index a9711e63d..0dd0aea83 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -23,6 +23,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/errors"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/sighandling"
)
// SegvError is returned when a safecopy function receives SIGSEGV.
@@ -132,10 +133,10 @@ func initializeAddresses() {
func init() {
initializeAddresses()
- if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
}
- if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
}
linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) {
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index 2365b2c0d..15f84abea 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -20,7 +20,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
)
// maxRegisterSize is the maximum register size used in memcpy and memclr. It
@@ -332,39 +331,3 @@ func errorFromFaultSignal(addr uintptr, sig int32) error {
panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
}
}
-
-// ReplaceSignalHandler replaces the existing signal handler for the provided
-// signal with the one that handles faults in safecopy-protected functions.
-//
-// It stores the value of the previously set handler in previous.
-//
-// This function will be called on initialization in order to install safecopy
-// handlers for appropriate signals. These handlers will call the previous
-// handler however, and if this is function is being used externally then the
-// same courtesy is expected.
-func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
- var sa linux.SigAction
- const maskLen = 8
-
- // Get the existing signal handler information, and save the current
- // handler. Once we replace it, we will use this pointer to fall back to
- // it when we receive other signals.
- if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
- return e
- }
-
- // Fail if there isn't a previous handler.
- if sa.Handler == 0 {
- return fmt.Errorf("previous handler for signal %x isn't set", sig)
- }
-
- *previous = uintptr(sa.Handler)
-
- // Install our own handler.
- sa.Handler = uint64(handler)
- if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
- return e
- }
-
- return nil
-}
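ReplaceSignalHandler is not deleted outright: per the diffstat, it moves (along with the SigAction plumbing above) into the new top-level pkg/sighandling package, which safecopy and the KVM platform now import instead.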
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 4370cce33..d2eb03bb7 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -45,7 +45,8 @@ type pipeOperations struct {
fsutil.FileNoIoctl `state:"nosave"`
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.Queue `state:"nosave"`
+
+ waiter.Queue
// flags are the flags used to open the pipe.
flags fs.FileFlags `state:".(fs.FileFlags)"`
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 92d58e3e9..99c37291e 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -70,7 +70,7 @@ type inodeFileState struct {
descriptor *descriptor `state:"wait"`
// Event queue for blocking operations.
- queue waiter.Queue `state:"zerovalue"`
+ queue waiter.Queue
// sattr is used to restore the inodeOperations.
sattr fs.StableAttr `state:"wait"`
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index 51cd6cd37..941f37116 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -43,7 +43,7 @@ type Inotify struct {
// user, since we may aggressively reuse an id on S/R.
id uint64
- waiter.Queue `state:"nosave"`
+ waiter.Queue
// evMu *only* protects the events list. We need a separate lock because
// while queuing events, a watch needs to lock the event queue, and using mu
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 7d7a207cc..e39d340fe 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -132,7 +132,7 @@ type Locks struct {
locks LockSet
// blockedQueue is the queue of waiters that are waiting on a lock.
- blockedQueue waiter.Queue `state:"zerovalue"`
+ blockedQueue waiter.Queue
}
// Blocker is the interface used for blocking locks. Passing a nil Blocker
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index 085aa6d61..443b9a94c 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -109,6 +109,9 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
+ "msgmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNI, 10))),
+ "msgmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMAX, 10))),
+ "msgmnb": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNB, 10))),
}
d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 1c8518d71..ca8be8683 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -43,7 +43,7 @@ type TimerOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- events waiter.Queue `state:"zerovalue"`
+ events waiter.Queue
timer *ktime.Timer
// val is the number of timer expirations since the last successful call to
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index f9fca6d8e..f2c9e9668 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -102,10 +102,10 @@ type lineDiscipline struct {
column int
// masterWaiter is used to wait on the master end of the TTY.
- masterWaiter waiter.Queue `state:"zerovalue"`
+ masterWaiter waiter.Queue
// replicaWaiter is used to wait on the replica end of the TTY.
- replicaWaiter waiter.Queue `state:"zerovalue"`
+ replicaWaiter waiter.Queue
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 7bef8242f..b98825e26 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1595,7 +1595,10 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats
// (b/148380782). Allow all other extended attributes to be passed through
// to the remote filesystem. This is inconsistent with Linux's 9p client,
// but consistent with other filesystems (e.g. FUSE).
- if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) {
+ //
+ // NOTE(b/202533394): Also disallow "trusted" namespace for now. This is
+ // consistent with the VFS1 gofer client.
+ if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) || strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
return linuxerr.EOPNOTSUPP
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
@@ -2046,16 +2049,7 @@ func (d *dentry) listXattr(ctx context.Context, size uint64) ([]string, error) {
}
if d.fs.opts.lisaEnabled {
- xattrs, err := d.controlFDLisa.ListXattr(ctx, size)
- if err != nil {
- return nil, err
- }
-
- res := make([]string, 0, len(xattrs))
- for _, xattr := range xattrs {
- res = append(res, xattr)
- }
- return res, nil
+ return d.controlFDLisa.ListXattr(ctx, size)
}
xattrMap, err := d.file.listXattr(ctx, size)
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 26d44744b..7b0be9c14 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -268,6 +268,6 @@ func cpuInfoData(k *kernel.Kernel) string {
return buf.String()
}
-func shmData(v uint64) dynamicInode {
+func ipcData(v uint64) dynamicInode {
return newStaticFile(strconv.FormatUint(v, 10))
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 99f64a9d8..82e2857b3 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -47,9 +47,12 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
"sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
- "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
+ "shmall": fs.newInode(ctx, root, 0444, ipcData(linux.SHMALL)),
+ "shmmax": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMAX)),
+ "shmmni": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMNI)),
+ "msgmni": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNI)),
+ "msgmax": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMAX)),
+ "msgmnb": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNB)),
"yama": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"ptrace_scope": fs.newYAMAPtraceScopeFile(ctx, k, root),
}),
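These new msg* entries mirror the System V message-queue limits added to VFS1's /proc/sys/kernel directory in the fs/proc/sys.go hunk above, so both proc implementations now expose msgmni, msgmax, and msgmnb.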
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 66fa1ad40..03c8e2f38 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -12,8 +12,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/fd",
- "//pkg/hostarch",
+ "//pkg/eventfd",
"//pkg/log",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
index 285ea9050..5df06a60f 100644
--- a/pkg/sentry/hostmm/hostmm.go
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -21,9 +21,7 @@ import (
"os"
"path"
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/fd"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
)
@@ -54,7 +52,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
}
defer eventControlFile.Close()
- eventFD, err := newEventFD()
+ eventFD, err := eventfd.Create()
if err != nil {
return nil, err
}
@@ -75,20 +73,11 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
const stopVal = 1 << 63
stopCh := make(chan struct{})
go func() { // S/R-SAFE: f provides synchronization if necessary
- rw := fd.NewReadWriter(eventFD.FD())
- var buf [sizeofUint64]byte
for {
- n, err := rw.Read(buf[:])
+ val, err := eventFD.Read()
if err != nil {
- if err == unix.EINTR {
- continue
- }
panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err))
}
- if n != sizeofUint64 {
- panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
- }
- val := hostarch.ByteOrder.Uint64(buf[:])
if val >= stopVal {
// Assume this was due to the notifier's "destructor" (the
// function returned by NotifyCurrentMemcgPressureCallback
@@ -101,30 +90,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
}
}()
return func() {
- rw := fd.NewReadWriter(eventFD.FD())
- var buf [sizeofUint64]byte
- hostarch.ByteOrder.PutUint64(buf[:], stopVal)
- for {
- n, err := rw.Write(buf[:])
- if err != nil {
- if err == unix.EINTR {
- continue
- }
- panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err))
- }
- if n != sizeofUint64 {
- panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
- }
- break
- }
+ eventFD.Write(stopVal)
<-stopCh
}, nil
}
-
-func newEventFD() (*fd.FD, error) {
- f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if e != 0 {
- return nil, fmt.Errorf("failed to create eventfd: %v", e)
- }
- return fd.New(int(f)), nil
-}
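With the eventfd package, the notifier goroutine and its "destructor" reduce to the plain Read/Write calls above. A hypothetical registration (sketch only; the callback is illustrative):

    // watchPressure logs whenever the host reports memory pressure. The
    // returned stop function writes stopVal to the eventfd and waits for
    // the notifier goroutine to exit.
    func watchPressure() (stop func(), err error) {
        // "low" is one of the cgroup memory pressure levels
        // ("low", "medium", "critical").
        return hostmm.NotifyCurrentMemcgPressureCallback(func() {
            log.Infof("host memory pressure") // Hypothetical reaction.
        }, "low")
    }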
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c0f13bf52..53a21e1e2 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -255,7 +255,6 @@ go_library(
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/epoll",
"//pkg/sentry/kernel/futex",
"//pkg/sentry/kernel/msgqueue",
"//pkg/sentry/kernel/sched",
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 6006c46a9..8d0a21baf 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -66,7 +66,7 @@ type pollEntry struct {
file *refs.WeakRef `state:"manual"`
id FileIdentifier `state:"wait"`
userData [2]int32
- waiter waiter.Entry `state:"manual"`
+ waiter waiter.Entry
mask waiter.EventMask
flags EntryFlags
@@ -102,7 +102,7 @@ type EventPoll struct {
// Wait queue is used to notify interested parties when the event poll
// object itself becomes readable or writable.
- waiter.Queue `state:"zerovalue"`
+ waiter.Queue
// files is the map of all the files currently being observed, it is
// protected by mu.
@@ -454,14 +454,3 @@ func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error {
return nil
}
-
-// UnregisterEpollWaiters removes the epoll waiter objects from the waiting
-// queues. This is different from Release() as the file is not dereferenced.
-func (e *EventPoll) UnregisterEpollWaiters() {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- for _, entry := range e.files {
- entry.id.File.EventUnregister(&entry.waiter)
- }
-}
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
index e08d6287f..135a6d72c 100644
--- a/pkg/sentry/kernel/epoll/epoll_state.go
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -21,9 +21,7 @@ import (
// afterLoad is invoked by stateify.
func (p *pollEntry) afterLoad() {
- p.waiter.Callback = p
p.file = refs.NewWeakRef(p.id.File, p)
- p.id.File.EventRegister(&p.waiter, p.mask)
}
// afterLoad is invoked by stateify.
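This deletion pairs with the waiter changes throughout this change: now that waiter.Entry and waiter.Queue are saved rather than tagged state:"manual" or state:"nosave", epoll registrations survive save/restore on their own, so afterLoad no longer re-registers entries and SaveTo no longer tears them down (see the kernel.go hunk below).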
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 5ea44a2c2..bf625dede 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -54,7 +54,7 @@ type EventOperations struct {
// Queue is used to notify interested parties when the event object
// becomes readable or writable.
- wq waiter.Queue `state:"zerovalue"`
+ wq waiter.Queue
// val is the current value of the event counter.
val uint64
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index df5160b67..5dc821a48 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -57,7 +57,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -78,11 +77,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow
-// easy access everywhere. To be removed once VFS2 becomes the default.
+// VFS2Enabled is set to true when VFS2 is enabled. Added as a global to allow
+// easy access everywhere.
+//
+// TODO(gvisor.dev/issue/1624): Remove when VFS1 is no longer used.
var VFS2Enabled = false
-// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow
+// LISAFSEnabled is set to true when lisafs protocol is enabled. Added as a
+// global to allow easy access everywhere.
+//
+// TODO(gvisor.dev/issue/6319): Remove when lisafs is default.
+var LISAFSEnabled = false
+
+// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow
// easy access everywhere. To be removed once FUSE is completed.
var FUSEEnabled = false
@@ -478,11 +485,6 @@ func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error {
return err
}
- // Remove all epoll waiter objects from underlying wait queues.
- // NOTE: for programs to resume execution in future snapshot scenarios,
- // we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters(ctx)
-
// Clear the dirent cache before saving because Dirents must be Loaded in a
// particular order (parents before children), and Loading dirents from a cache
// breaks that order.
@@ -615,32 +617,6 @@ func (k *Kernel) flushWritesToFiles(ctx context.Context) error {
})
}
-// Preconditions: !VFS2Enabled.
-func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
- ts.mu.RLock()
- defer ts.mu.RUnlock()
-
- // Tasks that belong to the same process could potentially point to the
- // same FDTable. So we retain a map of processed ones to avoid
- // processing the same FDTable multiple times.
- processed := make(map[*FDTable]struct{})
- for t := range ts.Root.tids {
- // We can skip locking Task.mu here since the kernel is paused.
- if t.fdTable == nil {
- continue
- }
- if _, ok := processed[t.fdTable]; ok {
- continue
- }
- t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
- e.UnregisterEpollWaiters()
- }
- })
- processed[t.fdTable] = struct{}{}
- }
-}
-
// Preconditions: The kernel must be paused.
func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
invalidated := make(map[*mm.MemoryManager]struct{})
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 86beee6fe..8345473f3 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -55,7 +55,7 @@ const (
//
// +stateify savable
type Pipe struct {
- waiter.Queue `state:"nosave"`
+ waiter.Queue
// isNamed indicates whether this is a named pipe.
//
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index b0004482c..1ea3c1bf7 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -158,7 +158,7 @@ type Task struct {
// signalQueue is protected by the signalMutex. Note that the task does
// not implement all queue methods, specifically the readiness checks.
// The task only broadcast a notification on signal delivery.
- signalQueue waiter.Queue `state:"zerovalue"`
+ signalQueue waiter.Queue
// If groupStopPending is true, the task should participate in a group
// stop in the interrupt path.
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index c5b099559..f0c168ecc 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -191,9 +191,11 @@ const (
//
// Preconditions: The task's owning TaskSet.mu must be locked.
func (t *Task) updateInfoLocked() {
- // Use the task's TID in the root PID namespace for logging.
+ // Use the task's TID and PID in the root PID namespace for logging.
+ pid := t.tg.pidns.owner.Root.tgids[t.tg]
tid := t.tg.pidns.owner.Root.tids[t]
- t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.logPrefix.Store(fmt.Sprintf("[% 4d:% 4d] ", pid, tid))
+
t.rebuildTraceContext(tid)
}
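With this change a task's log prefix carries both identifiers: for example, a task with PID 12 and TID 34 in the root PID namespace now logs as "[  12:  34] " rather than "[  34] ".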
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 77ad62445..e38b723ce 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -324,11 +324,7 @@ type threadGroupNode struct {
// eventQueue is notified whenever a event of interest to Task.Wait occurs
// in a child of this thread group, or a ptrace tracee of a task in this
// thread group. Events are defined in task_exit.go.
- //
- // Note that we cannot check and save this wait queue similarly to other
- // wait queues, as the queue will not be empty by the time of saving, due
- // to the wait sourced from Exec().
- eventQueue waiter.Queue `state:"nosave"`
+ eventQueue waiter.Queue
// leader is the thread group's leader, which is the oldest task in the
// thread group; usually the last task in the thread group to call
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index a26f54269..834d72408 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -63,7 +63,6 @@ go_library(
"//pkg/procid",
"//pkg/ring0",
"//pkg/ring0/pagetables",
- "//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
@@ -71,6 +70,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/time",
+ "//pkg/sighandling",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 826997e77..5be2215ed 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -19,8 +19,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/ring0"
- "gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sighandling"
)
// bluepill enters guest mode.
@@ -97,7 +97,7 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 0567c8d32..b2db2bb9f 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -71,10 +71,6 @@ func (c *vCPU) KernelSyscall() {
if regs.Rax != ^uint64(0) {
regs.Rip -= 2 // Rewind.
}
- // We only trigger a bluepill entry in the bluepill function, and can
- // therefore be guaranteed that there is no floating point state to be
- // loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
// N.B. Since KernelSyscall is called when the kernel makes a syscall,
// FS_BASE is already set for correct execution of this function.
//
@@ -112,8 +108,6 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
- // See above.
ring0.HaltAndWriteFSBase(regs) // escapes: no, reload host segment.
}
@@ -144,5 +138,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// Set the context pointer to the saved floating point state. This is
// where the guest data has been serialized, the kernel will restore
// from this new pointer value.
- context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer()))
+ context.Fpstate = uint64(uintptrValue(c.FloatingPointState().BytePointer())) // escapes: no.
}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index acb0cb05f..df772d620 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -70,7 +70,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
lazyVfp := c.GetLazyVFP()
if lazyVfp != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
context.Fpsimd64.Fpsr = fpsimd.Fpsr
context.Fpsimd64.Fpcr = fpsimd.Fpcr
context.Fpsimd64.Vregs = fpsimd.Vregs
@@ -90,12 +90,12 @@ func (c *vCPU) KernelSyscall() {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(c.floatingPointState.BytePointer())
+ ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no
}
ring0.Halt()
@@ -114,12 +114,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(c.floatingPointState.BytePointer())
+ ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no
}
ring0.Halt()
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index aac0fdffe..ad6863646 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -77,7 +77,11 @@ var (
// OpenDevice opens the KVM device at /dev/kvm and returns the File.
func OpenDevice() (*os.File, error) {
- f, err := os.OpenFile("/dev/kvm", unix.O_RDWR, 0)
+ dev, ok := os.LookupEnv("GVISOR_KVM_DEV")
+ if !ok {
+ dev = "/dev/kvm"
+ }
+ f, err := os.OpenFile(dev, unix.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("error opening /dev/kvm: %v", err)
}
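The new environment variable lets tests and non-standard hosts point the KVM platform at an alternate device node. A hypothetical use (sketch only; the path is illustrative):

    // Override the device node before the platform opens it.
    os.Setenv("GVISOR_KVM_DEV", "/dev/my-kvm")
    f, err := kvm.OpenDevice()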
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index dcf34015d..f1f7e4ea4 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -28,9 +28,9 @@ import (
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/seccomp"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sighandling"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -723,7 +723,7 @@ func addrOfSigsysHandler() uintptr
func seccompMmapRules(m *machine) {
seccompMmapRulesOnce.Do(func() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
rules := []seccomp.RuleSet{}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index ab1e036b7..5bc023899 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
@@ -72,10 +71,6 @@ type vCPUArchState struct {
//
// This starts above fixedKernelPCID.
PCIDs *pagetables.PCIDs
-
- // floatingPointState is the floating point state buffer used in guest
- // to host transitions. See usage in bluepill_amd64.go.
- floatingPointState fpu.State
}
const (
@@ -152,12 +147,6 @@ func (c *vCPU) initArchState() error {
return fmt.Errorf("error setting user registers: %v", errno)
}
- // Allocate some floating point state save area for the local vCPU.
- // This will be saved prior to leaving the guest, and we restore from
- // this always. We cannot use the pointer in the context alone because
- // we don't know how large the area there is in reality.
- c.floatingPointState = fpu.NewState()
-
// Set the time offset to the host native time.
return c.setSystemTime()
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 08d98c479..31998a600 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -40,10 +39,6 @@ type vCPUArchState struct {
//
// This starts above fixedKernelPCID.
PCIDs *pagetables.PCIDs
-
- // floatingPointState is the floating point state buffer used in guest
- // to host transitions. See usage in bluepill_arm64.go.
- floatingPointState fpu.State
}
const (
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 7e8e19dcb..e73d5c544 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
@@ -159,8 +158,6 @@ func (c *vCPU) initArchState() error {
c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
}
- c.floatingPointState = fpu.NewState()
-
return c.setSystemTime()
}
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 7ee89a735..00f925166 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "socket",
- srcs = ["socket.go"],
+ srcs = [
+ "socket.go",
+ "socket_state.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index f9a5b0df1..6077b2150 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -29,10 +29,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "time"
)
-const maxInt = int(^uint(0) >> 1)
-
// SCMCredentials represents a SCM_CREDENTIALS socket control message.
type SCMCredentials interface {
transport.CredentialsControlMessage
@@ -78,7 +77,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
}
// Files implements SCMRights.Files.
-func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
+func (fs *RightsFiles) Files(_ context.Context, max int) (RightsFiles, bool) {
n := max
var trunc bool
if l := len(*fs); n > l {
@@ -124,7 +123,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
break
}
- fds = append(fds, int32(fd))
+ fds = append(fds, fd)
}
return fds, trunc
}
@@ -300,8 +299,8 @@ func alignSlice(buf []byte, align uint) []byte {
}
// PackTimestamp packs a SO_TIMESTAMP socket control message.
-func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
- timestampP := linux.NsecToTimeval(timestamp)
+func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte {
+ timestampP := linux.NsecToTimeval(timestamp.UnixNano())
return putCmsgStruct(
buf,
linux.SOL_SOCKET,
@@ -545,7 +544,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var ts linux.Timeval
ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
- cmsgs.IP.Timestamp = ts.ToNsecCapped()
+ cmsgs.IP.Timestamp = ts.ToTime()
cmsgs.IP.HasTimestamp = true
i += bits.AlignUp(length, width)
diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go
index 7e28a0cef..1b04e1bbc 100644
--- a/pkg/sentry/socket/control/control_test.go
+++ b/pkg/sentry/socket/control/control_test.go
@@ -50,7 +50,7 @@ func TestParse(t *testing.T) {
want := socket.ControlMessages{
IP: socket.IPControlMessages{
HasTimestamp: true,
- Timestamp: ts.ToNsecCapped(),
+ Timestamp: ts.ToTime(),
},
}
if diff := cmp.Diff(want, cmsg); diff != "" {
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 1c1e501ba..6e2318f75 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -111,7 +111,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
}))
- return int64(n), err
+ return n, err
}
// Write implements fs.FileOperations.Write.
@@ -134,7 +134,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
}))
- return int64(n), err
+ return n, err
}
// Socket implements socket.Provider.Socket.
@@ -180,7 +180,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
@@ -207,7 +207,7 @@ type socketOpsCommon struct {
// Release implements fs.FileOperations.Release.
func (s *socketOpsCommon) Release(context.Context) {
fdnotifier.RemoveFD(int32(s.fd))
- unix.Close(s.fd)
+ _ = unix.Close(s.fd)
}
// Readiness implements waiter.Waitable.Readiness.
@@ -218,13 +218,13 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
// EventRegister implements waiter.Waitable.EventRegister.
func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(s.fd))
+ _ = fdnotifier.UpdateFD(int32(s.fd))
}
// EventUnregister implements waiter.Waitable.EventUnregister.
func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(s.fd))
+ _ = fdnotifier.UpdateFD(int32(s.fd))
}
// Connect implements socket.Socket.Connect.
@@ -316,7 +316,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
if kernel.VFS2Enabled {
f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK))
if err != nil {
- unix.Close(fd)
+ _ = unix.Close(fd)
return 0, nil, 0, err
}
defer f.DecRef(t)
@@ -328,7 +328,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
} else {
f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0)
if err != nil {
- unix.Close(fd)
+ _ = unix.Close(fd)
return 0, nil, 0, err
}
defer f.DecRef(t)
@@ -343,7 +343,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
}
// Bind implements socket.Socket.Bind.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -356,12 +356,12 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Listen implements socket.Socket.Listen.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.FromError(unix.Listen(s.fd, backlog))
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
switch how {
case unix.SHUT_RD, unix.SHUT_WR, unix.SHUT_RDWR:
return syserr.FromError(unix.Shutdown(s.fd, how))
@@ -371,7 +371,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, _ hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -401,7 +401,7 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
case linux.TCP_NODELAY:
optlen = sizeofInt32
case linux.TCP_INFO:
- optlen = int(linux.SizeOfTCPInfo)
+ optlen = linux.SizeOfTCPInfo
}
}
@@ -579,7 +579,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
controlMessages.IP.HasTimestamp = true
ts := linux.Timeval{}
ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
- controlMessages.IP.Timestamp = ts.ToNsecCapped()
+ controlMessages.IP.Timestamp = ts.ToTime()
}
case linux.SOL_IP:
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 0f6e576a9..b9c15daab 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID {
}
// Action implements stack.Target.Action.
-func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index bf5ec4558..075f61cda 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"device.go",
"netstack.go",
+ "netstack_state.go",
"netstack_vfs2.go",
"provider.go",
"provider_vfs2.go",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index dedc32dda..030c6c8e4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -274,6 +274,7 @@ var Metrics = tcpip.Stats{
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."),
SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."),
+ SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
@@ -378,9 +379,9 @@ type socketOpsCommon struct {
// timestampValid indicates whether timestamp for SIOCGSTAMP has been
// set. It is protected by readMu.
timestampValid bool
- // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
+ // timestamp holds the timestamp to use with SIOCGSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
- timestampNS int64
+ timestamp time.Time `state:".(int64)"`
// TODO(b/153685824): Move this to SocketOptions.
// sockOptInq corresponds to TCP_INQ.
@@ -410,15 +411,6 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes()
var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes()
var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes()
-// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
-// netstack representation taking any addresses into account.
-func bytesToIPAddress(addr []byte) tcpip.Address {
- if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
- return ""
- }
- return tcpip.Address(addr)
-}
-
// minSockAddrLen returns the minimum length in bytes of a socket address for
// the socket's family.
func (s *socketOpsCommon) minSockAddrLen() int {
@@ -468,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
t := kernel.TaskFromContext(ctx)
start := t.Kernel().MonotonicClock().Now()
deadline := start.Add(v.Timeout)
- t.BlockWithDeadline(ch, true, deadline)
+ _ = t.BlockWithDeadline(ch, true, deadline)
}
}
@@ -488,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// WriteTo implements fs.FileOperations.WriteTo.
-func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
defer s.readMu.Unlock()
@@ -543,7 +535,7 @@ func (l *limitedPayloader) Len() int {
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
f := limitedPayloader{
inner: io.LimitedReader{
R: r,
@@ -654,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < 2 {
return syserr.ErrInvalidArgument
}
@@ -714,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Listen implements the linux syscall listen(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
}
@@ -805,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
f, err := ConvertShutdown(how)
if err != nil {
return err
@@ -886,7 +878,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1402,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true)
if err != nil {
return nil, err
}
@@ -1422,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1442,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
@@ -1459,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
return nil, syserr.ErrUnknownProtocolOption
@@ -1599,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false)
if err != nil {
return nil, err
}
@@ -1619,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1639,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber)
@@ -2186,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true)
case linux.IP6T_SO_SET_ADD_COUNTERS:
log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported")
@@ -2429,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
@@ -2601,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2613,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2774,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
- s.timestampNS = cm.Timestamp
+ s.timestamp = cm.Timestamp
}
}
@@ -2833,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int,
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
if flags&linux.MSG_ERRQUEUE != 0 {
return s.recvErr(t, dst)
}
@@ -2998,7 +2990,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
return 0, linuxerr.ENOENT
}
- tv := linux.NsecToTimeval(s.timestampNS)
+ tv := linux.NsecToTimeval(s.timestamp.UnixNano())
_, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
@@ -3105,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// interfaceIoctl implements interface requests.
-func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
var (
iface inet.Interface
index int32
@@ -3113,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
)
// Find the relevant device.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice
}
@@ -3124,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4]))
- if iface, ok := stack.Interfaces()[index]; ok {
+ if iface, ok := stk.Interfaces()[index]; ok {
ifr.SetName(iface.Name)
return nil
}
@@ -3132,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// Find the relevant device.
- for index, iface = range stack.Interfaces() {
+ for index, iface = range stk.Interfaces() {
if iface.Name == ifr.Name() {
found = true
break
@@ -3165,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
case linux.SIOCGIFFLAGS:
- f, err := interfaceStatusFlags(stack, iface.Name)
+ f, err := interfaceStatusFlags(stk, iface.Name)
if err != nil {
return err
}
@@ -3175,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3211,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3243,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice.ToError()
}
if ifc.Ptr == 0 {
- ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq)
return nil
}
max := ifc.Len
ifc.Len = 0
- for key, ifaceAddrs := range stack.InterfaceAddrs() {
- iface := stack.Interfaces()[key]
+ for key, ifaceAddrs := range stk.InterfaceAddrs() {
+ iface := stk.Interfaces()[key]
for _, ifaceAddr := range ifaceAddrs {
// Don't write past the end of the buffer.
if ifc.Len+int32(linux.SizeOfIFReq) > max {
diff --git a/pkg/sentry/socket/netstack/netstack_state.go b/pkg/sentry/socket/netstack/netstack_state.go
new file mode 100644
index 000000000..591e00d42
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack_state.go
@@ -0,0 +1,31 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "time"
+)
+
+func (s *socketOpsCommon) saveTimestamp() int64 {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ return s.timestamp.UnixNano()
+}
+
+func (s *socketOpsCommon) loadTimestamp(nsec int64) {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.timestamp = time.Unix(0, nsec)
+}
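
The saveTimestamp/loadTimestamp pair implements the `state:".(int64)"` tag on the timestamp field: the checkpoint machinery stores the int64 returned by the save hook and feeds it back to the load hook on restore. A standalone sketch of the same round trip, independent of the stateify framework:

    package main

    import (
    	"fmt"
    	"time"
    )

    func main() {
    	// Save: reduce the time.Time to nanoseconds since the Unix epoch.
    	orig := time.Now()
    	nsec := orig.UnixNano()

    	// Load: reconstruct the instant from the saved nanoseconds.
    	restored := time.Unix(0, nsec)
    	fmt.Println(orig.Equal(restored)) // true
    }
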
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 2f0eb4a6c..d4b80a39d 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"sync/atomic"
+ "time"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -51,8 +52,8 @@ type ControlMessages struct {
func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
var p linux.ControlMessageIPPacketInfo
p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ copy(p.LocalAddr[:], packetInfo.LocalAddr)
+ copy(p.DestinationAddr[:], packetInfo.DestinationAddr)
return p
}
@@ -60,7 +61,7 @@ func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPack
// format.
func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo {
var p linux.ControlMessageIPv6PacketInfo
- if n := copy(p.Addr[:], []byte(packetInfo.Addr)); n != len(p.Addr) {
+ if n := copy(p.Addr[:], packetInfo.Addr); n != len(p.Addr) {
panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr)))
}
p.NIC = uint32(packetInfo.NIC)
@@ -156,9 +157,9 @@ type IPControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
diff --git a/pkg/sentry/socket/socket_state.go b/pkg/sentry/socket/socket_state.go
new file mode 100644
index 000000000..32e12b238
--- /dev/null
+++ b/pkg/sentry/socket/socket_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package socket
+
+import (
+ "time"
+)
+
+func (i *IPControlMessages) saveTimestamp() int64 {
+ return i.Timestamp.UnixNano()
+}
+
+func (i *IPControlMessages) loadTimestamp(nsec int64) {
+ i.Timestamp = time.Unix(0, nsec)
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 757ff2a40..4d3f4d556 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -610,9 +610,9 @@ func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output []
if err == nil {
// Fill in the output after successful execution.
i.post(t, args, retval, output, LogMaximumSize)
- rval = fmt.Sprintf("%#x (%v)", retval, elapsed)
+ rval = fmt.Sprintf("%d (%#x) (%v)", retval, retval, elapsed)
} else {
- rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed)
+ rval = fmt.Sprintf("%d (%#x) errno=%d (%s) (%v)", retval, retval, errno, err, elapsed)
}
switch len(output) {
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 04bc4d10c..fefd0fc9c 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -135,12 +135,16 @@ func (ep *EpollInstance) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
ep.mu.Lock()
- for epi := ep.ready.Front(); epi != nil; epi = epi.Next() {
+ var next *epollInterest
+ for epi := ep.ready.Front(); epi != nil; epi = next {
+ next = epi.Next()
wmask := waiter.EventMaskFromLinux(epi.mask)
if epi.key.file.Readiness(wmask)&wmask != 0 {
ep.mu.Unlock()
return waiter.ReadableEvents
}
+ ep.ready.Remove(epi)
+ epi.ready = false
}
ep.mu.Unlock()
return 0
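
The Readiness loop above now prunes entries that are no longer ready, so it must capture epi.Next() before calling Remove. A generic, runnable sketch of that iterate-and-remove pattern (using the standard container/list in place of the generated intrusive list):

    package main

    import (
    	"container/list"
    	"fmt"
    )

    func main() {
    	l := list.New()
    	for i := 0; i < 5; i++ {
    		l.PushBack(i)
    	}
    	// Capture Next() before a potential Remove() so iteration survives
    	// the removal of the current element.
    	var next *list.Element
    	for e := l.Front(); e != nil; e = next {
    		next = e.Next()
    		if e.Value.(int)%2 == 0 {
    			l.Remove(e)
    		}
    	}
    	fmt.Println(l.Len()) // 2
    }
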
diff --git a/pkg/shim/service.go b/pkg/shim/service.go
index 24e3b7a82..0980d964e 100644
--- a/pkg/shim/service.go
+++ b/pkg/shim/service.go
@@ -77,6 +77,8 @@ const (
// shimAddressPath is the relative path to a file that contains the address
// to the shim UDS. See service.shimAddress.
shimAddressPath = "address"
+
+ cgroupParentAnnotation = "dev.gvisor.spec.cgroup-parent"
)
// New returns a new shim service that can be used via GRPC.
@@ -952,7 +954,7 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C
if err != nil {
return nil, fmt.Errorf("update volume annotations: %w", err)
}
- updated = updateCgroup(spec) || updated
+ updated = setPodCgroup(spec) || updated
if updated {
if err := utils.WriteSpec(r.Bundle, spec); err != nil {
@@ -980,12 +982,13 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C
return p, nil
}
-// updateCgroup updates cgroup path for the sandbox to make the sandbox join the
-// pod cgroup and not the pause container cgroup. Returns true if the spec was
-// modified. Ex.:
-// /kubepods/burstable/pod123/abc => kubepods/burstable/pod123
+// setPodCgroup searches for the pod cgroup path inside the container's cgroup
+// path. If found, it's set as an annotation in the spec. This is done so that
+// the sandbox joins the pod cgroup. Otherwise, the sandbox would join the pause
+// container cgroup. Returns true if the spec was modified. Ex.:
+// /kubepods/burstable/pod123/container123 => kubepods/burstable/pod123
//
-func updateCgroup(spec *specs.Spec) bool {
+func setPodCgroup(spec *specs.Spec) bool {
if !utils.IsSandbox(spec) {
return false
}
@@ -1009,7 +1012,10 @@ func updateCgroup(spec *specs.Spec) bool {
if spec.Linux.CgroupsPath == path {
return false
}
- spec.Linux.CgroupsPath = path
+ if spec.Annotations == nil {
+ spec.Annotations = make(map[string]string)
+ }
+ spec.Annotations[cgroupParentAnnotation] = path
return true
}
}
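
For reference, a rough sketch of the path trimming setPodCgroup performs: find the pod segment in the container's cgroup path and drop everything after it. The helper below is a hypothetical illustration, not the shim's implementation:

    package main

    import (
    	"fmt"
    	"strings"
    )

    // podCgroupOf returns the pod portion of a container cgroup path, or ""
    // when no "pod"-prefixed segment is followed by a container segment.
    func podCgroupOf(path string) string {
    	parts := strings.Split(path, "/")
    	for i, p := range parts {
    		if strings.HasPrefix(p, "pod") && i < len(parts)-1 {
    			return strings.Join(parts[:i+1], "/")
    		}
    	}
    	return ""
    }

    func main() {
    	fmt.Println(podCgroupOf("/kubepods/burstable/pod123/container123"))
    	// Output: /kubepods/burstable/pod123
    }
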
diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go
index 2d9f07e02..4b4410a58 100644
--- a/pkg/shim/service_test.go
+++ b/pkg/shim/service_test.go
@@ -40,12 +40,12 @@ func TestCgroupPath(t *testing.T) {
{
name: "no-container",
path: "foo/pod123",
- want: "foo/pod123",
+ want: "",
},
{
name: "no-container-absolute",
path: "/foo/pod123",
- want: "/foo/pod123",
+ want: "",
},
{
name: "double-pod",
@@ -70,7 +70,7 @@ func TestCgroupPath(t *testing.T) {
{
name: "no-pod",
path: "/foo/nopod123/container",
- want: "/foo/nopod123/container",
+ want: "",
},
} {
t.Run(tc.name, func(t *testing.T) {
@@ -79,12 +79,12 @@ func TestCgroupPath(t *testing.T) {
CgroupsPath: tc.path,
},
}
- updated := updateCgroup(&spec)
- if spec.Linux.CgroupsPath != tc.want {
- t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath)
+ updated := setPodCgroup(&spec)
+ if got := spec.Annotations[cgroupParentAnnotation]; got != tc.want {
+ t.Errorf("setPodCgroup(%q), want: %q, got: %q", tc.path, tc.want, got)
}
- if shouldUpdate := tc.path != tc.want; shouldUpdate != updated {
- t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate)
+ if shouldUpdate := len(tc.want) > 0; shouldUpdate != updated {
+ t.Errorf("setPodCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate)
}
})
}
@@ -113,8 +113,8 @@ func TestCgroupNoUpdate(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
- if updated := updateCgroup(tc.spec); updated {
- t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated)
+ if updated := setPodCgroup(tc.spec); updated {
+ t.Errorf("setPodCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated)
}
})
}
diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sighandling/BUILD
index 1790d57c9..72f10f982 100644
--- a/pkg/sentry/sighandling/BUILD
+++ b/pkg/sighandling/BUILD
@@ -8,7 +8,7 @@ go_library(
"sighandling.go",
"sighandling_unsafe.go",
],
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sighandling/sighandling.go
index bdaf8af29..bdaf8af29 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sighandling/sighandling.go
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sighandling/sighandling_unsafe.go
index 3fe5c6770..7deeda042 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sighandling/sighandling_unsafe.go
@@ -15,6 +15,7 @@
package sighandling
import (
+ "fmt"
"unsafe"
"golang.org/x/sys/unix"
@@ -37,3 +38,36 @@ func IgnoreChildStop() error {
return nil
}
+
+// ReplaceSignalHandler replaces the existing signal handler for the provided
+// signal with the function pointer at `handler`. This bypasses the Go runtime
+// signal handlers, and should only be used for low-level signal handlers where
+// use of signal.Notify is not appropriate.
+//
+// It stores the value of the previously set handler in previous.
+func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
+ var sa linux.SigAction
+ const maskLen = 8
+
+ // Get the existing signal handler information, and save the current
+ // handler. Once we replace it, we will use this pointer to fall back to
+ // it when we receive other signals.
+ if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ // Fail if there isn't a previous handler.
+ if sa.Handler == 0 {
+ return fmt.Errorf("previous handler for signal %x isn't set", sig)
+ }
+
+ *previous = uintptr(sa.Handler)
+
+ // Install our own handler.
+ sa.Handler = uint64(handler)
+ if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
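
With the move out of the sentry, ReplaceSignalHandler can be shared by the KVM platform (see machine.go above). A hedged sketch of the call pattern; addrOfHandler would be the address of an assembly trampoline and is hypothetical here:

    package main

    import (
    	"fmt"

    	"golang.org/x/sys/unix"
    	"gvisor.dev/gvisor/pkg/sighandling"
    )

    // savedHandler receives the previous handler so the new handler can
    // chain to it for signals it does not consume.
    var savedHandler uintptr

    func install(addrOfHandler uintptr) {
    	if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfHandler, &savedHandler); err != nil {
    		panic(fmt.Sprintf("unable to replace SIGSYS handler: %v", err))
    	}
    }
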
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 73791b456..517f16329 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -26,6 +26,7 @@ go_library(
"rwmutex_unsafe.go",
"seqcount.go",
"sync.go",
+ "wait.go",
],
marshal = False,
stateify = False,
diff --git a/pkg/sync/wait.go b/pkg/sync/wait.go
new file mode 100644
index 000000000..f8e7742a5
--- /dev/null
+++ b/pkg/sync/wait.go
@@ -0,0 +1,58 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sync
+
+// WaitGroupErr is similar to WaitGroup but allows goroutines to report errors.
+// Only the first error is retained and reported back.
+//
+// Example usage:
+// wg := WaitGroupErr{}
+// wg.Add(1)
+// go func() {
+// defer wg.Done()
+// if err := ...; err != nil {
+// wg.ReportError(err)
+// return
+// }
+// }()
+// return wg.Error()
+//
+type WaitGroupErr struct {
+ WaitGroup
+
+ // mu protects firstErr.
+ mu Mutex
+
+ // firstErr holds the first error reported. nil if no error occurred.
+ firstErr error
+}
+
+// ReportError reports an error. Note it does not call Done().
+func (w *WaitGroupErr) ReportError(err error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.firstErr == nil {
+ w.firstErr = err
+ }
+}
+
+// Error waits for the counter to reach 0 and returns the first reported error
+// if any.
+func (w *WaitGroupErr) Error() error {
+ w.Wait()
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.firstErr
+}
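
A runnable version of the usage pattern from the doc comment, fanning work out to several goroutines and surfacing the first failure:

    package main

    import (
    	"fmt"

    	"gvisor.dev/gvisor/pkg/sync"
    )

    func main() {
    	var wg sync.WaitGroupErr
    	for i := 0; i < 3; i++ {
    		wg.Add(1)
    		go func(i int) {
    			defer wg.Done()
    			if i == 2 {
    				// Only the first reported error is retained.
    				wg.ReportError(fmt.Errorf("worker %d failed", i))
    			}
    		}(i)
    	}
    	// Error waits for all workers, then returns the first error (if any).
    	fmt.Println(wg.Error())
    }
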
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index dbe4506cc..b98de54c5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -25,6 +25,7 @@ go_library(
"stdclock.go",
"stdclock_state.go",
"tcpip.go",
+ "tcpip_state.go",
"timer.go",
],
visibility = ["//visibility:public"],
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 87a0b9a62..e53789d92 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -152,10 +152,22 @@ type PollEvent struct {
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
+ n, err := BlockingReadUntranslated(fd, b)
+ if err != 0 {
+ return n, TranslateErrno(err)
+ }
+ return n, nil
+}
+
+// BlockingReadUntranslated reads from a file descriptor that is set up as
+// non-blocking. If no data is available, it will block in a poll() syscall
+// until the file descriptor becomes readable. It returns the raw unix.Errno
+// value returned by the underlying syscalls.
+func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
- return int(n), nil
+ return int(n), 0
}
event := PollEvent{
@@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
_, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
- return 0, TranslateErrno(e)
+ return 0, e
}
}
}
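
BlockingReadUntranslated exists for callers that want the raw unix.Errno, for example to avoid allocating a tcpip.Error on hot paths. A minimal sketch under that assumption; readFrame is a hypothetical caller:

    package example

    import (
    	"golang.org/x/sys/unix"

    	"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
    )

    // readFrame reads one frame and surfaces the raw errno, leaving any
    // translation to a tcpip.Error up to the caller.
    func readFrame(fd int, buf []byte) (int, unix.Errno) {
    	n, errno := rawfile.BlockingReadUntranslated(fd, buf)
    	if errno != 0 {
    		return 0, errno
    	}
    	return n, 0
    }
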
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 4215ee852..f8076d83c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -5,19 +5,26 @@ package(licenses = ["notice"])
go_library(
name = "sharedmem",
srcs = [
+ "queuepair.go",
"rx.go",
+ "server_rx.go",
+ "server_tx.go",
"sharedmem.go",
+ "sharedmem_server.go",
"sharedmem_unsafe.go",
"tx.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/cleanup",
+ "//pkg/eventfd",
"//pkg/log",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
"//pkg/tcpip/stack",
"@org_golang_x_sys//unix:go_default_library",
@@ -26,9 +33,7 @@ go_library(
go_test(
name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
+ srcs = ["sharedmem_test.go"],
library = ":sharedmem",
deps = [
"//pkg/sync",
@@ -41,3 +46,22 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sharedmem_server_test",
+ size = "small",
+ srcs = ["sharedmem_server_test.go"],
+ deps = [
+ ":sharedmem",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..a78826ebc 100644
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
}
r.tx.Flush()
-
return true
}
@@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
for {
outBufs := bufs
-
// Pull the next descriptor from the rx pipe.
b := r.rx.Pull()
if b == nil {
diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go
new file mode 100644
index 000000000..b12647fdd
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queuepair.go
@@ -0,0 +1,199 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
+)
+
+const (
+ // defaultQueueDataSize is the size of the shared memory data region that
+ // holds the scatter/gather buffers.
+ defaultQueueDataSize = 1 << 20 // 1MiB
+
+ // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors.
+ //
+ // Assuming each packet is approximately 1280 bytes (the IPv6 minimum
+ // MTU), the data area can hold approximately 1024*1024/1280 ~ 819
+ // packets, which means the pipe needs to be big enough to hold 819
+ // descriptors.
+ //
+ // Each descriptor costs approximately 36 bytes: 8 (slot descriptor in
+ // pipe) + 16 (packet descriptor) + 12 (buffer descriptor), assuming each
+ // packet is stored in exactly 1 buffer descriptor (see queue/tx.go and
+ // pipe/tx.go).
+ //
+ // That means we need approximately 36*819 ~ 29 KiB to store all packet
+ // descriptors. A 32 KiB pipe would be enough, but we double it to give
+ // the upper layer some slack in how it uses the scatter/gather buffers.
+ defaultQueuePipeSize = 64 << 10 // 64KiB
+
+ // defaultSharedDataSize is the size of the sharedData region used to
+ // enable/disable notifications.
+ defaultSharedDataSize = 4 << 10 // 4KiB
+)
+
+// A QueuePair represents a pair of TX/RX queues.
+type QueuePair struct {
+ // txCfg is the QueueConfig to be used for the transmit queue.
+ txCfg QueueConfig
+
+ // rxCfg is the QueueConfig to be used for the receive queue.
+ rxCfg QueueConfig
+}
+
+// NewQueuePair creates a shared memory QueuePair.
+func NewQueuePair() (*QueuePair, error) {
+ txCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tx queue: %s", err)
+ }
+
+ rxCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ closeFDs(txCfg)
+ return nil, fmt.Errorf("failed to create rx queue: %s", err)
+ }
+
+ return &QueuePair{
+ txCfg: txCfg,
+ rxCfg: rxCfg,
+ }, nil
+}
+
+// Close closes underlying tx/rx queue fds.
+func (q *QueuePair) Close() {
+ closeFDs(q.txCfg)
+ closeFDs(q.rxCfg)
+}
+
+// TXQueueConfig returns the QueueConfig for the transmit queue.
+func (q *QueuePair) TXQueueConfig() QueueConfig {
+ return q.txCfg
+}
+
+// RXQueueConfig returns the QueueConfig for the receive queue.
+func (q *QueuePair) RXQueueConfig() QueueConfig {
+ return q.rxCfg
+}
+
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+func createQueueFDs(s queueSizes) (QueueConfig, error) {
+ success := false
+ var eventFD eventfd.Eventfd
+ var dataFD, txPipeFD, rxPipeFD, sharedDataFD int
+ defer func() {
+ if success {
+ return
+ }
+ closeFDs(QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ })
+ }()
+ eventFD, err := eventfd.Create()
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err)
+ }
+ dataFD, err = createFile(s.dataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err)
+ }
+ txPipeFD, err = createFile(s.txPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err)
+ }
+ rxPipeFD, err = createFile(s.rxPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err)
+ }
+ sharedDataFD, err = createFile(s.sharedDataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err)
+ }
+ success = true
+ return QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ }, nil
+}
+
+func createFile(size int64, initQueue bool) (fd int, err error) {
+ const tmpDir = "/dev/shm/"
+ f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+ if err != nil {
+ return -1, fmt.Errorf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ unix.Unlink(f.Name())
+
+ if initQueue {
+ // Write the "slot-free" flag in the initial queue.
+ if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil {
+ return -1, fmt.Errorf("WriteAt failed: %v", err)
+ }
+ }
+
+ fd, err = unix.Dup(int(f.Fd()))
+ if err != nil {
+ return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err)
+ }
+
+ if err := unix.Ftruncate(fd, size); err != nil {
+ unix.Close(fd)
+ return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err)
+ }
+
+ return fd, nil
+}
+
+func closeFDs(c QueueConfig) {
+ unix.Close(c.DataFD)
+ c.EventFD.Close()
+ unix.Close(c.TxPipeFD)
+ unix.Close(c.RxPipeFD)
+ unix.Close(c.SharedDataFD)
+}
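
A short sketch of the new helper in use: create the pair, hand the two QueueConfigs to the two endpoints, and close the pair when the link is torn down:

    package main

    import (
    	"log"

    	"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
    )

    func main() {
    	qp, err := sharedmem.NewQueuePair()
    	if err != nil {
    		log.Fatalf("NewQueuePair: %v", err)
    	}
    	defer qp.Close()

    	// One endpoint transmits on the TX config while the peer consumes
    	// the same shared memory through the mirrored RX side, and vice versa.
    	log.Printf("tx data fd: %d, rx data fd: %d",
    		qp.TXQueueConfig().DataFD, qp.RXQueueConfig().DataFD)
    }
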
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index e882a128c..87747dcc7 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -21,7 +21,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -30,7 +30,7 @@ type rx struct {
data []byte
sharedData []byte
q queue.Rx
- eventFD int
+ eventFD eventfd.Eventfd
}
// init initializes all state needed by the rx queue based on the information
@@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
// Duplicate the eventFD so that caller can close it but we can still
// use it.
- efd, err := unix.Dup(c.EventFD)
+ efd, err := c.EventFD.Dup()
if err != nil {
unix.Munmap(txPipe)
unix.Munmap(rxPipe)
@@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
return err
}
- // Set the eventfd as non-blocking.
- if err := unix.SetNonblock(efd, true); err != nil {
- unix.Munmap(txPipe)
- unix.Munmap(rxPipe)
- unix.Munmap(data)
- unix.Munmap(sharedData)
- unix.Close(efd)
- return err
- }
-
// Initialize state based on buffers.
r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
r.data = data
@@ -105,7 +95,13 @@ func (r *rx) cleanup() {
unix.Munmap(r.data)
unix.Munmap(r.sharedData)
- unix.Close(r.eventFD)
+ r.eventFD.Close()
+}
+
+// notify writes to the rx.eventFD to indicate to the peer that there is data to
+// be read.
+func (r *rx) notify() {
+ r.eventFD.Notify()
}
// postAndReceive posts the provided buffers (if any), and then tries to read
@@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
if len(b) != 0 && !r.q.PostBuffers(b) {
r.q.EnableNotification()
for !r.q.PostBuffers(b) {
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
@@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
}
// Wait for notification.
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
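
The raw 8-byte reads on the event FD are now encapsulated by pkg/eventfd: Notify wakes the peer and Wait blocks until a notification arrives. A minimal sketch based on the usage in this diff:

    package main

    import (
    	"log"

    	"gvisor.dev/gvisor/pkg/eventfd"
    )

    func main() {
    	efd, err := eventfd.Create()
    	if err != nil {
    		log.Fatalf("eventfd.Create: %v", err)
    	}
    	defer efd.Close()

    	go func() {
    		// Notify increments the eventfd counter, waking any waiter.
    		efd.Notify()
    	}()

    	// Wait blocks until the counter is non-zero and consumes the value.
    	efd.Wait()
    	log.Println("notified")
    }
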
diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go
new file mode 100644
index 000000000..6ea21ffd1
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_rx.go
@@ -0,0 +1,142 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+type serverRx struct {
+ // packetPipe represents the receive end of the pipe that carries the packet
+ // descriptors sent by the client.
+ packetPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that will carry
+ // completion notifications from the server to the client.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when transmission is completed.
+ eventFD eventfd.Eventfd
+
+ // sharedData is the memory region used to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverRx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverRx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ packetPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := c.EventFD.Dup()
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { efd.Close() })
+
+ s.packetPipe.Init(packetPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ cu.Release()
+ return nil
+}
+
+func (s *serverRx) cleanup() {
+ unix.Munmap(s.packetPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// completionNotificationSize is the size in bytes of a completion notification sent
+// on the completion queue after a transmitted packet has been handled.
+const completionNotificationSize = 8
+
+// receive receives a single packet from the packetPipe.
+func (s *serverRx) receive() []byte {
+ desc := s.packetPipe.Pull()
+ if desc == nil {
+ return nil
+ }
+
+ pktInfo := queue.DecodeTxPacketHeader(desc)
+ contents := make([]byte, 0, pktInfo.Size)
+ toCopy := pktInfo.Size
+ for i := 0; i < pktInfo.BufferCount; i++ {
+ txBuf := queue.DecodeTxBufferHeader(desc, i)
+ if txBuf.Size <= toCopy {
+ contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...)
+ toCopy -= txBuf.Size
+ continue
+ }
+ contents = append(contents, s.data[txBuf.Offset:][:toCopy]...)
+ break
+ }
+
+ // Flush to let the peer know that the slots queued for transmission have
+ // been handled and it's free to reuse them.
+ s.packetPipe.Flush()
+ // Encode packet completion.
+ b := s.completionPipe.Push(completionNotificationSize)
+ queue.EncodeTxCompletion(b, pktInfo.ID)
+ s.completionPipe.Flush()
+ return contents
+}
+
+func (s *serverRx) waitForPackets() {
+ s.eventFD.Wait()
+}
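
Taken together, a server-side receive loop blocks on waitForPackets and then drains the packet pipe with receive until it is empty. Since serverRx is package-private, the driver below is a hypothetical sketch of how sharedmem_server.go would use it:

    // rxLoop drains every queued packet after each notification. deliver is
    // a hypothetical callback that hands the packet to the stack.
    func rxLoop(s *serverRx, deliver func([]byte)) {
    	for {
    		s.waitForPackets()
    		for {
    			pkt := s.receive()
    			if pkt == nil {
    				break // pipe drained; go back to waiting
    			}
    			deliver(pkt)
    		}
    	}
    }
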
diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go
new file mode 100644
index 000000000..13a82903f
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_tx.go
@@ -0,0 +1,175 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverTx represents the server end of the sharedmem queue and is used to
+// send packets to the peer using the buffers that the peer posted via the
+// fillPipe.
+type serverTx struct {
+ // fillPipe represents the receive end of the pipe that carries the RxBuffers
+ // posted by the peer.
+ fillPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that carries the
+ // descriptors for filled RxBuffers.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when fill requests are fulfilled.
+ eventFD eventfd.Eventfd
+
+ // sharedData is the memory region used to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverTx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ fillPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that the caller can close it but we can still
+ // use it.
+ efd, err := c.EventFD.Dup()
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { efd.Close() })
+
+ cu.Release()
+
+ s.fillPipe.Init(fillPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ return nil
+}
+
+func (s *serverTx) cleanup() {
+ unix.Munmap(s.fillPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// fillPacket copies the data in the provided views into buffers pulled from the
+// fillPipe and returns a slice of RxBuffers that contain the copied data as
+// well as the total number of bytes copied.
+//
+// To avoid allocations, the filled buffers are appended to the provided
+// buffers slice, which is grown as required.
+func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) {
+ filledBuffers = buffers[:0]
+ // fillBuffer copies as much of the views as possible into the provided
+ // buffer and returns the left-over views, if any.
+ fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) {
+ if len(views) == 0 {
+ return nil
+ }
+ availBytes := buffer.Size
+ copied := uint64(0)
+ for availBytes > 0 && len(views) > 0 {
+ n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0])
+ views[0].TrimFront(n)
+ if !views[0].IsEmpty() {
+ break
+ }
+ views = views[1:]
+ copied += uint64(n)
+ availBytes -= uint32(n)
+ }
+ buffer.Size = uint32(copied)
+ return views
+ }
+
+ for len(views) > 0 {
+ var b []byte
+ // Spin until we get a free buffer reposted by the peer.
+ for {
+ if b = s.fillPipe.Pull(); b != nil {
+ break
+ }
+ }
+ rxBuffer := queue.DecodeRxBufferHeader(b)
+ // Copy the packet into the posted buffer.
+ views = fillBuffer(&rxBuffer, views)
+ totalCopied += rxBuffer.Size
+ filledBuffers = append(filledBuffers, rxBuffer)
+ }
+
+ return filledBuffers, totalCopied
+}
+
+func (s *serverTx) transmit(views []buffer.View) bool {
+ buffers := make([]queue.RxBuffer, 8)
+ buffers, totalCopied := s.fillPacket(views, buffers)
+ b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers)))
+ if b == nil {
+ return false
+ }
+ queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */)
+ for i := 0; i < len(buffers); i++ {
+ queue.EncodeRxCompletionBuffer(b, i, buffers[i])
+ }
+ s.completionPipe.Flush()
+ s.fillPipe.Flush()
+ return true
+}
+
+func (s *serverTx) notify() {
+ s.eventFD.Notify()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 66efe6472..bcb37a465 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -24,14 +24,16 @@
package sharedmem
import (
+ "fmt"
"sync/atomic"
- "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,7 +49,7 @@ type QueueConfig struct {
// EventFD is a file descriptor for the event that is signaled when
// data becomes available in this queue.
- EventFD int
+ EventFD eventfd.Eventfd
// TxPipeFD is a file descriptor for the tx pipe associated with the
// queue.
@@ -63,16 +65,97 @@ type QueueConfig struct {
SharedDataFD int
}
+// FDs returns the FDs in the QueueConfig as a slice of ints. It must be
+// used in conjunction with QueueConfigFromFDs so that the order of FDs
+// matches when the config is reconstructed after being serialized or sent
+// as part of a control message.
+func (q *QueueConfig) FDs() []int {
+ return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
+}
+
+// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each
+// entry represents a file descriptor. The order of FDs in the slice must be in
+// the order specified below for the config to be valid. QueueConfig.FDs()
+// should be used when the config needs to be serialized or sent as part of a
+// control message to ensure the correct order.
+func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
+ if len(fds) != 5 {
+ return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds))
+ }
+ return QueueConfig{
+ DataFD: fds[0],
+ EventFD: eventfd.Wrap(fds[1]),
+ TxPipeFD: fds[2],
+ RxPipeFD: fds[3],
+ SharedDataFD: fds[4],
+ }, nil
+}
+
+// Options specify the details about the sharedmem endpoint to be created.
+type Options struct {
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // BufferSize is the size of each scatter/gather buffer that will hold packet
+ // data.
+ //
+ // NOTE: This directly determines the number of packets that can be held in
+ // the ring buffer at any time. It does not have to be sized to the MTU, as
+ // the shared memory queue design allows more than one buffer to be used to
+ // make up a given packet.
+ BufferSize uint32
+
+ // LinkAddress is the link address for this endpoint (required).
+ LinkAddress tcpip.LinkAddress
+
+ // TX is the transmit queue configuration for this shared memory endpoint.
+ TX QueueConfig
+
+ // RX is the receive queue configuration for this shared memory endpoint.
+ RX QueueConfig
+
+ // PeerFD is the fd for the connected peer which can be used to detect
+ // peer disconnects.
+ PeerFD int
+
+ // OnClosed is a function that is called when the endpoint is being closed
+ // (probably due to the peer going away).
+ OnClosed func(err tcpip.Error)
+
+ // TXChecksumOffload, if true, indicates that this endpoint's capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload, if true, indicates that this endpoint's capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
+}
+
type endpoint struct {
// mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
mtu uint32
// bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
bufferSize uint32
// addr is the local address of this endpoint.
+ // addr is immutable.
addr tcpip.LinkAddress
+ // peerFD is an fd to the peer that can be used to detect when the
+ // peer is gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
// rx is the receive queue.
rx rx
@@ -83,34 +166,55 @@ type endpoint struct {
// Wait group used to indicate that all workers have stopped.
completed sync.WaitGroup
+ // onClosed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ onClosed func(tcpip.Error)
+
// mu protects the following fields.
mu sync.Mutex
// tx is the transmit queue.
+ // +checklocks:mu
tx tx
// workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
workerStarted bool
}
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+func New(opts Options) (stack.LinkEndpoint, error) {
e := &endpoint{
- mtu: mtu,
- bufferSize: bufferSize,
- addr: addr,
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
}
- if err := e.tx.init(bufferSize, &tx); err != nil {
+ if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
return nil, err
}
- if err := e.rx.init(bufferSize, &rx); err != nil {
+ if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
e.tx.cleanup()
return nil, err
}
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
return e, nil
}
@@ -119,13 +223,13 @@ func (e *endpoint) Close() {
// Tell dispatch goroutine to stop, then write to the eventfd so that
// it wakes up in case it's sleeping.
atomic.StoreUint32(&e.stopRequested, 1)
- unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ e.rx.eventFD.Notify()
// Cleanup the queues inline if the worker hasn't started yet; we also
// know it won't start from now on because stopRequested is set to 1.
e.mu.Lock()
+ defer e.mu.Unlock()
workerPresent := e.workerStarted
- e.mu.Unlock()
if !workerPresent {
e.tx.cleanup()
@@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+
+ // Spin up a goroutine to monitor for peer shutdown.
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ go func() {
+ defer e.completed.Done()
+ b := make([]byte, 1)
+ // When the sharedmem endpoint is in use, the peerFD is never used for any
+ // data transfer and this Read should only return if the peer is shutting
+ // down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ }()
+ }
+
// Link endpoints are not savable. When transportation endpoints
// are saved, they stop sending outgoing packets and all
// incoming packets are rejected.
@@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool {
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
- return e.mtu - header.EthernetMinimumSize
+ return e.mtu - e.hdrSize
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
-func (*endpoint) MaxHeaderLength() uint16 {
- return header.EthernetMinimumSize
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
@@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WriteRawPacket implements stack.LinkEndpoint.
func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+// +checklocks:e.mu
+func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
// Transmit the packet.
- e.mu.Lock()
ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
if !ok {
return &tcpip.ErrWouldBlock{}
}
@@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
return nil
}
+// WritePacket writes an outbound packet to the shared memory queue. If the
+// queue is not currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
}
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
@@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
Data: buffer.View(b).ToVectorisedView(),
})
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- continue
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // The IP version is in the high nibble of the first octet, so pull up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
}
- eth := header.Ethernet(hdr)
// Send packet up the stack.
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
}
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
// Clean state.
e.tx.cleanup()
e.rx.cleanup()
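
FDs and QueueConfigFromFDs above define a positional contract, which matters because the FDs are typically passed over a unix socket control message and only their order survives. A hedged sketch of the round trip (cfg is an already-populated QueueConfig; the error handling is illustrative):

    // Serialize: FDs() always returns {data, event, tx pipe, rx pipe, shared
    // data}, and QueueConfigFromFDs depends on exactly that order.
    fds := cfg.FDs() // e.g. to send via SCM_RIGHTS

    // Reconstruct on the receiving side.
    cfg2, err := sharedmem.QueueConfigFromFDs(fds)
    if err != nil {
        return fmt.Errorf("rebuilding queue config: %w", err) // len(fds) != 5
    }
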
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
new file mode 100644
index 000000000..ccc84989d
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -0,0 +1,333 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type serverEndpoint struct {
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
+ mtu uint32
+
+ // bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
+ bufferSize uint32
+
+ // addr is the local address of this endpoint.
+ // addr is immutable.
+ addr tcpip.LinkAddress
+
+ // rx is the receive queue.
+ rx serverRx
+
+ // stopRequested is to be accessed atomically only, and determines if the
+ // worker goroutines should stop.
+ stopRequested uint32
+
+ // Wait group used to indicate that all workers have stopped.
+ completed sync.WaitGroup
+
+ // peerFD is an fd to the peer that can be used to detect when the peer is
+ // gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
+ // onClosed is a function to be called when the FD's peer (if any) closes its
+ // end of the communication pipe.
+ onClosed func(tcpip.Error)
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // tx is the transmit queue.
+ // +checklocks:mu
+ tx serverTx
+
+ // workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
+ workerStarted bool
+}
+
+// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
+// broken up into buffers of "bufferSize" bytes.
+func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
+ e := &serverEndpoint{
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
+ }
+
+ if err := e.tx.init(&opts.RX); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(&opts.TX); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
+
+ return e, nil
+}
+
+// Close frees all resources associated with the endpoint.
+func (e *serverEndpoint) Close() {
+ // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes
+ // up in case it's sleeping.
+ atomic.StoreUint32(&e.stopRequested, 1)
+ e.rx.eventFD.Notify()
+
+ // Cleanup the queues inline if the worker hasn't started yet; we also know it
+ // won't start from now on because stopRequested is set to 1.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ workerPresent := e.workerStarted
+
+ if !workerPresent {
+ e.tx.cleanup()
+ e.rx.cleanup()
+ }
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
+func (e *serverEndpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue.
+func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
+ e.workerStarted = true
+ e.completed.Add(1)
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ // Spin up a goroutine to monitor for peer shutdown.
+ go func() {
+ b := make([]byte, 1)
+ // When the sharedmem endpoint is in use, the peerFD is never used for
+ // any data transfer and this Read should only return if the peer is
+ // shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ e.completed.Done()
+ }()
+ }
+ // Link endpoints are not savable. When transportation endpoints are saved,
+ // they stop sending outgoing packets and all incoming packets are rejected.
+ go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+ }
+ e.mu.Unlock()
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *serverEndpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *serverEndpoint) MTU() uint32 {
+ return e.mtu - e.hdrSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (e *serverEndpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: remote,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// +checklocks:e.mu
+func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+
+ views := pkt.Views()
+ ok := e.tx.transmit(views)
+ if !ok {
+ return &tcpip.ErrWouldBlock{}
+ }
+
+ return nil
+}
+
+// WritePacket writes an outbound packet to the shared memory queue. If the
+// queue is not currently writable, the packet is dropped.
+func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ // Transmit the packet.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack.
+func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
+ for atomic.LoadUint32(&e.stopRequested) == 0 {
+ b := e.rx.receive()
+ if b == nil {
+ e.rx.waitForPackets()
+ continue
+ }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(b).ToVectorisedView(),
+ })
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // The IP version is in the high nibble of the first octet, so pull up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
+ }
+ // Send packet up the stack.
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Clean state.
+ e.tx.cleanup()
+ e.rx.cleanup()
+
+ e.completed.Done()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
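
Both dispatch loops guess the network protocol from the first octet when the endpoint has no link-layer header; the IP version is the high nibble of that octet, which is what header.IPVersion extracts. As a standalone sketch:

    // guessProtocol classifies a raw (headerless) packet by its first octet.
    func guessProtocol(b []byte) (tcpip.NetworkProtocolNumber, bool) {
        if len(b) == 0 {
            return 0, false
        }
        switch header.IPVersion(b) {
        case header.IPv4Version:
            return header.IPv4ProtocolNumber, true
        case header.IPv6Version:
            return header.IPv6ProtocolNumber, true
        default:
            return 0, false // not IP; the dispatch loops drop such packets
        }
    }
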
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
new file mode 100644
index 000000000..1bc58614e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem_server_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+ localIPv4Address = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02")
+ serverPort = 10001
+
+ defaultMTU = 1500
+ defaultBufferSize = 1500
+)
+
+type stackOptions struct {
+ ep stack.LinkEndpoint
+ addr tcpip.Address
+}
+
+func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) {
+ st := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ })
+ nicID := tcpip.NICID(1)
+ sniffEP := sniffer.New(stackOpts.ep)
+ opts := stack.NICOptions{Name: "eth0"}
+ if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil {
+ return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
+ }
+
+ // Add Protocol Address.
+ protocolNum := ipv4.ProtocolNumber
+ routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
+ if len(stackOpts.addr) == 16 {
+ routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
+ protocolNum = ipv6.ProtocolNumber
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: protocolNum,
+ AddressWithPrefix: stackOpts.addr.WithPrefix(),
+ }
+ if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ // Setup route table.
+ st.SetRouteTable(routeTable)
+
+ return st, nil
+}
+
+func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.New(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: localLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: remoteLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+type testContext struct {
+ clientStk *stack.Stack
+ serverStk *stack.Stack
+ peerFDs [2]int
+}
+
+func newTestContext(t *testing.T) *testContext {
+ peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ t.Fatalf("failed to create peerFDs: %s", err)
+ }
+ q, err := sharedmem.NewQueuePair()
+ if err != nil {
+ t.Fatalf("failed to create sharedmem queue: %s", err)
+ }
+ clientStack, err := newClientStack(t, q, peerFDs[0])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ t.Fatalf("failed to create client stack: %s", err)
+ }
+ serverStack, err := newServerStack(t, q, peerFDs[1])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ clientStack.Close()
+ t.Fatalf("failed to create server stack: %s", err)
+ }
+ return &testContext{
+ clientStk: clientStack,
+ serverStk: serverStack,
+ peerFDs: peerFDs,
+ }
+}
+
+func (ctx *testContext) cleanup() {
+ unix.Close(ctx.peerFDs[0])
+ unix.Close(ctx.peerFDs[1])
+ ctx.clientStk.Close()
+ ctx.serverStk.Close()
+}
+
+func TestServerRoundTrip(t *testing.T) {
+ ctx := newTestContext(t)
+ defer ctx.cleanup()
+ listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
+ l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("failed to start TCP Listener: %s", err)
+ }
+ defer l.Close()
+ var responseString = "response"
+ go func() {
+ http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(responseString))
+ }))
+ }()
+
+ dialFunc := func(network, address string) (net.Conn, error) {
+ return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
+ }
+
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ Dial: dialFunc,
+ },
+ }
+ serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort)
+ response, err := httpClient.Get(serverURL)
+ if err != nil {
+ t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+ }
+ if got, want := response.StatusCode, http.StatusOK; got != want {
+ t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+ }
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+ }
+ response.Body.Close()
+ if got, want := string(body), responseString; got != want {
+ t.Fatalf("unexpected response got: %s, want: %s", got, want)
+ }
+}
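
The SOCK_SEQPACKET socketpair in newTestContext is used purely as a liveness channel: neither side ever writes to it, so the BlockingRead in Attach returns only when the other end is closed. The pattern in isolation, as a sketch (onPeerGone is a hypothetical callback, cf. Options.OnClosed):

    // Illustrative only: detect peer shutdown with a socketpair.
    fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0)
    if err != nil {
        return err
    }
    go func() {
        b := make([]byte, 1)
        // Blocks; a SOCK_SEQPACKET read returns 0 bytes once the peer end is
        // closed. sharedmem uses rawfile.BlockingRead for the same purpose.
        if n, _ := unix.Read(fds[0], b); n == 0 {
            onPeerGone()
        }
    }()
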
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d6d953085..66ffc33b8 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -19,9 +19,7 @@ package sharedmem
import (
"bytes"
- "io/ioutil"
"math/rand"
- "os"
"strings"
"testing"
"time"
@@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
t: t,
packetCh: make(chan struct{}, 1000000),
}
- c.txCfg = createQueueFDs(t, queueSizes{
+ c.txCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
-
- c.rxCfg = createQueueFDs(t, queueSizes{
+ if err != nil {
+ t.Fatalf("createQueueFDs for tx failed: %s", err)
+ }
+ c.rxCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
+ if err != nil {
+ t.Fatalf("createQueueFDs for rx failed: %s", err)
+ }
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(Options{
+ MTU: mtu,
+ BufferSize: bufferSize,
+ LinkAddress: addr,
+ TX: c.txCfg,
+ RX: c.rxCfg,
+ PeerFD: -1,
+ })
if err != nil {
t.Fatalf("New failed: %v", err)
}
@@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.
func (c *testContext) cleanup() {
c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
+ closeFDs(c.txCfg)
+ closeFDs(c.rxCfg)
c.txq.cleanup()
c.rxq.cleanup()
}
@@ -191,69 +201,6 @@ func shuffle(b []int) {
}
}
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
// TestSimpleSend sends 1000 packets with random header and payload sizes,
// then checks that the right payload is received on the shared memory queues.
func TestSimpleSend(t *testing.T) {
@@ -672,7 +619,7 @@ func TestSimpleReceive(t *testing.T) {
// Push completion.
c.pushRxCompletion(uint32(len(contents)), bufs)
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
@@ -718,7 +665,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete the buffer.
c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for it to be reposted.
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
@@ -734,7 +681,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete with two buffers.
c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for them to be reposted.
for j := 0; j < 2; j++ {
@@ -759,7 +706,7 @@ func TestReceivePostingIsFull(t *testing.T) {
first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that packet is received.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
@@ -768,7 +715,7 @@ func TestReceivePostingIsFull(t *testing.T) {
second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that no packet is received yet, as the worker is blocked trying
// to repost.
@@ -781,7 +728,7 @@ func TestReceivePostingIsFull(t *testing.T) {
// Flush tx queue, which will allow the first buffer to be reposted,
// and the second completion to be pulled.
c.rxq.tx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that second packet completes.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
@@ -803,7 +750,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) {
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be indicated.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
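
All of the unix.Write(..., []byte{1, 0, ...}) call sites above collapse into eventFD.Notify(), using the new pkg/eventfd wrapper introduced in this change. A minimal sketch of the wrapper's use, assuming the API exercised elsewhere in this diff (Wrap, Notify, Wait, Close) plus the x/sys/unix Eventfd helper:

    // Create a raw eventfd and wrap it.
    rawFD, err := unix.Eventfd(0, 0)
    if err != nil {
        return err
    }
    efd := eventfd.Wrap(rawFD)
    defer efd.Close()

    go func() {
        efd.Notify() // equivalent to writing {1, 0, 0, 0, 0, 0, 0, 0} to the fd
    }()
    efd.Wait() // blocks until a notification arrives
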
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index e3210051f..35e5bff12 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -28,10 +29,12 @@ const (
// tx holds all state associated with a tx queue.
type tx struct {
- data []byte
- q queue.Tx
- ids idManager
- bufs bufferManager
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+ eventFD eventfd.Eventfd
+ sharedDataFD int
}
// init initializes all state needed by the tx queue based on the information
@@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error {
t.ids.init()
t.bufs.init(0, len(data), int(mtu))
t.data = data
-
+ t.eventFD = c.EventFD
+ t.sharedDataFD = c.SharedDataFD
return nil
}
@@ -142,6 +146,12 @@ func (t *tx) transmit(bufs ...buffer.View) bool {
return true
}
+// notify writes to the tx.eventFD to indicate to the peer that there is data to
+// be read.
+func (t *tx) notify() {
+ t.eventFD.Notify()
+}
+
// getBuffer returns a memory region mapped to the full contents of the given
// file descriptor.
func getBuffer(fd int) ([]byte, error) {
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index d51c36f19..1c3b0887f 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -167,14 +167,17 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
p := hdr.TransportProtocol()
dstAddr := hdr.DestinationAddress()
// Skip the ip header, then deliver the error.
- pkt.Data().DeleteFront(hlen)
+ if _, ok := pkt.Data().Consume(hlen); !ok {
+ panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen))
+ }
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
// ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
+ // icmp/protocol.go:protocol.Parse for a full explanation. Not all ICMP types
+ // require consuming the header, so we only call PullUp.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
@@ -242,7 +245,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
- // be modifying fields.
+ // be modifying fields. Both the ICMP header (with its type modified to
+ // EchoReply) and payload are reused in the reply packet.
//
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
@@ -331,6 +335,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.echoReply.Increment()
+ // ICMP sockets expect the ICMP header to be present, so we don't consume
+ // the ICMP header.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
@@ -338,7 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
mtu := h.MTU()
code := h.Code()
- pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ if _, ok := pkt.Data().Consume(header.ICMPv4MinimumSize); !ok {
+ panic("could not consume ICMPv4MinimumSize bytes")
+ }
switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
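
The ICMP changes hinge on the difference between peeking at and consuming packet data. A sketch of the two calls, with pkt assumed to be a *stack.PacketBuffer:

    // PullUp returns the first n bytes without shrinking pkt.Data(); used
    // where later consumers (e.g. ICMP sockets) still need the header.
    hdr, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)

    // Consume returns the same bytes and removes them from pkt.Data(); used
    // once the header is fully handled, so that the embedded payload is at
    // the front. This replaces the old PullUp-then-DeleteFront pairs.
    hdr2, ok2 := pkt.Data().Consume(header.ICMPv4MinimumSize)
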
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index dda473e48..9b71738ae 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -466,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -576,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 6c6107264..ff23d48e7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().DeleteFront(header.IPv6MinimumSize)
+ if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok {
+ panic("could not consume IPv6MinimumSize bytes")
+ }
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
+ if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok {
+ panic("could not consume IPv6FragmentHeaderSize bytes")
+ }
}
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
@@ -325,7 +329,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.packetTooBig.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ hdr, ok := pkt.Data().Consume(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
received.invalid.Increment()
return
@@ -334,18 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
if err != nil {
networkMTU = 0
}
- pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ hdr, ok := pkt.Data().Consume(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
received.invalid.Increment()
return
}
code := header.ICMPv6(hdr).Code()
- pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e2d2cf907..600e805f8 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -1537,19 +1537,22 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
- // Calculate the number of octets parsed from data. We want to remove all
- // the data except the unparsed portion located at the end, which its size
- // is extHdr.Buf.Size().
+ // Calculate the number of octets parsed from data. We want to consume all
+ // the data except the unparsed portion located at the end, whose size is
+ // extHdr.Buf.Size().
trim := pkt.Data().Size() - extHdr.Buf.Size()
// For unfragmented packets, extHdr still contains the transport header.
- // Get rid of it.
+ // Consume that too.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
trim += pkt.TransportHeader().View().Size()
- pkt.Data().DeleteFront(trim)
+ if _, ok := pkt.Data().Consume(trim); !ok {
+ stats.MalformedPacketsReceived.Increment()
+ return fmt.Errorf("could not consume %d bytes", trim)
+ }
stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 4fb7e9adb..48f290187 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -45,17 +45,6 @@ const (
dirReply
)
-// Manipulation type for the connection.
-// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
-// DNAT at the same time.
-type manipType int
-
-const (
- manipNone manipType = iota
- manipSource
- manipDestination
-)
-
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
//
@@ -64,13 +53,21 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
// direction is the direction of the tuple.
direction direction
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
@@ -103,50 +100,47 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
- manip manipType
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb. It is immutable.
- tcbHook Hook
-
- // mu protects all mutable state.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // Indicates that the connection has been finalized and may handle replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // sourceManip indicates the packet's source is manipulated.
+ //
+ // +checklocks:mu
+ sourceManip bool
+ // destinationManip indicates the packet's destination is manipulated.
+ //
+ // +checklocks:mu
+ destinationManip bool
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection and is protected by mu.
+ // of tcp connection.
+ //
+ // +checklocks:mu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
- // is updated by each packet on the connection. It is protected by mu.
+ // is updated by each packet on the connection.
//
// TODO(gvisor.dev/issue/5939): do not use the ambient clock.
+ //
+ // +checklocks:mu
lastUsed time.Time `state:".(unixTime)"`
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
const defaultTimeout = 120 * time.Second
- cn.mu.Lock()
- defer cn.mu.Unlock()
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -159,8 +153,9 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:cn.mu
+func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) {
if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
return
}
@@ -172,10 +167,16 @@ func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
// established or not, so the client/server distinction isn't important.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
- } else if hook == cn.tcbHook {
+ return
+ }
+
+ switch dir {
+ case dirOriginal:
cn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
+ case dirReply:
cn.tcb.UpdateStateInbound(tcpHeader)
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
}
}
@@ -200,18 +201,18 @@ type ConnTrack struct {
// It is immutable.
seed uint32
+ mu sync.RWMutex `state:"nosave"`
// mu protects the buckets slice, but not buckets' contents. Only take
// the write lock if you are modifying the slice or saving for S/R.
- mu sync.RWMutex `state:"nosave"`
-
- // buckets is protected by mu.
+ //
+ // +checklocks:mu
buckets []bucket
}
// +stateify savable
type bucket struct {
- // mu protects tuples.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
tuples tupleList
}
@@ -230,241 +231,212 @@ func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool)
return nil, false
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
+func (ct *ConnTrack) init() {
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+ ct.buckets = make([]bucket, numBuckets)
+}
+
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
netHeader := pkt.Network()
transportHeader, ok := getTransportHeader(pkt)
if !ok {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+ return nil
}
- return tupleID{
+ tid := tupleID{
srcAddr: netHeader.SourceAddress(),
srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
dstPort: transportHeader.DestinationPort(),
transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
- }, nil
-}
-
-func (ct *ConnTrack) init() {
- ct.mu.Lock()
- defer ct.mu.Unlock()
- ct.buckets = make([]bucket, numBuckets)
-}
-
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
}
- return ct.connForTID(tid)
-}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
- bucket := ct.bucket(tid)
- now := time.Now()
+ bktID := ct.bucket(tid)
ct.mu.RLock()
- defer ct.mu.RUnlock()
- ct.buckets[bucket].mu.Lock()
- defer ct.buckets[bucket].mu.Unlock()
-
- // Iterate over the tuples in a bucket, cleaning up any unused
- // connections we find.
- for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
- // Clean up any timed-out connections we happen to find.
- if ct.reapTupleLocked(other, bucket, now) {
- // The tuple expired.
- continue
- }
- if tid == other.tupleID {
- return other.conn, other.direction
- }
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := time.Now()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
}
- return nil, dirOriginal
-}
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
+ // Make sure a connection wasn't added between our last check of the
+ // bucket and our acquisition of the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
}
- if hook != Prerouting && hook != Output {
- return nil
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid, direction: dirOriginal},
+ reply: tuple{tupleID: tid.reply(), direction: dirReply},
+ lastUsed: now,
}
+ conn.original.conn = conn
+ conn.reply.conn = conn
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
+}
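
The lookup above is a double-checked pattern: a cheap read-locked probe for the common case of an existing connection, then a re-check under the write lock before inserting. A minimal standalone sketch of the same discipline (illustrative names, not the actual gVisor types):

```go
package sketch

import "sync"

type entry struct{ id string }

type bucket struct {
	mu      sync.RWMutex
	entries map[string]*entry
}

// getOrInsert mirrors getConnOrMaybeInsertNoop's locking: probe under
// the read lock first, then re-check under the write lock so that two
// packets racing on the same new connection insert only one entry.
func (b *bucket) getOrInsert(key string) *entry {
	b.mu.RLock()
	e := b.entries[key]
	b.mu.RUnlock()
	if e != nil {
		return e
	}

	b.mu.Lock()
	defer b.mu.Unlock()
	if e := b.entries[key]; e != nil {
		// Lost the race: another goroutine inserted first.
		return e
	}
	e = &entry{id: key}
	b.entries[key] = e
	return e
}
```
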
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
- }
- conn = newConn(tid, replyTID, manipDestination, hook)
- ct.insertConn(conn)
- return conn
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
+ bktID := ct.bucket(tid)
+
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ return bkt.connForTID(tid, time.Now())
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocks:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
+ for other := bkt.tuples.Front(); other != nil; other = other.Next() {
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
+ }
}
+ return nil
+}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
+
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil {
+		// Another connection for the reply already exists. We can't do much about
+		// this, so we leave the connection cn represents in a state where it can
+		// send packets, but its responses will be mapped to some other connection.
+		// This may be okay if the connection only expects to send packets without
+		// receiving any responses.
+ return
}
- conn = newConn(tid, replyTID, manipSource, hook)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- // Lock the buckets in the correct order.
- tupleBucket := ct.bucket(conn.original.tupleID)
- replyBucket := ct.bucket(conn.reply.tupleID)
- ct.mu.RLock()
- defer ct.mu.RUnlock()
- if tupleBucket < replyBucket {
- ct.buckets[tupleBucket].mu.Lock()
- ct.buckets[replyBucket].mu.Lock()
- } else if tupleBucket > replyBucket {
- ct.buckets[replyBucket].mu.Lock()
- ct.buckets[tupleBucket].mu.Lock()
- } else {
- // Both tuples are in the same bucket.
- ct.buckets[tupleBucket].mu.Lock()
- }
-
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
}
}
- if !alreadyInserted {
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- // Unlocking can happen in any order.
- ct.buckets[tupleBucket].mu.Unlock()
- if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
- }
+ cn.ct.finalize(cn)
}
-// handlePacket will manipulate the port and address of the packet if the
-// connection exists. Returns whether, after the packet traverses the tables,
-// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
- }
+// performNAT sets up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method;
+// other packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
- switch hook {
- case Prerouting, Input, Output, Postrouting:
- default:
- return false
- }
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
- transportHeader, ok := getTransportHeader(pkt)
- if !ok {
- return false
+ if cn.finalized {
+ return
}
- conn, dir := ct.connFor(pkt)
- // Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+ if dnat {
+ if cn.destinationManip {
+ return
+ }
+ cn.destinationManip = true
+ } else {
+ if cn.sourceManip {
+ return
+ }
+ cn.sourceManip = true
}
- netHeader := pkt.Network()
-
- // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
- // validated if checksum offloading is off. It may require IP defrag if the
- // packets are fragmented.
-
- var newAddr tcpip.Address
- var newPort uint16
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
- updateSRCFields := false
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ }
+}
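
Why the reply tuple is patched on opposite fields for SNAT and DNAT is easier to see with concrete values; a self-contained sketch under simplified types (the real tupleID also carries protocol numbers):

```go
package sketch

type tupleID struct {
	srcAddr, dstAddr string
	srcPort, dstPort uint16
}

// reply returns the inverted tuple, as conntrack's tupleID.reply does.
func (t tupleID) reply() tupleID {
	return tupleID{
		srcAddr: t.dstAddr, srcPort: t.dstPort,
		dstAddr: t.srcAddr, dstPort: t.srcPort,
	}
}

// natReply mirrors performNATIfNoop: DNAT changes whom replies appear
// to come from (the reply's source), while SNAT changes where replies
// must be sent (the reply's destination).
func natReply(orig tupleID, addr string, port uint16, dnat bool) tupleID {
	r := orig.reply()
	if dnat {
		r.srcAddr, r.srcPort = addr, port
	} else {
		r.dstAddr, r.dstPort = addr, port
	}
	return r
}
```
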
- switch hook {
- case Prerouting, Output:
- if conn.manip == manipDestination && dir == dirOriginal {
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- pkt.NatDone = true
- } else if conn.manip == manipSource && dir == dirReply {
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- pkt.NatDone = true
- }
- case Input, Postrouting:
- if conn.manip == manipSource && dir == dirOriginal {
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
- updateSRCFields = true
- pkt.NatDone = true
- } else if conn.manip == manipDestination && dir == dirReply {
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
- updateSRCFields = true
- pkt.NatDone = true
- }
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
+ if pkt.NatDone {
+ return
}
- if !pkt.NatDone {
- return false
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return
}
fullChecksum := false
updatePseudoHeader := false
+ dnat := false
switch hook {
case Prerouting:
// Packet came from outside the stack so it must have a checksum set
// already.
fullChecksum = true
updatePseudoHeader = true
+
+ dnat = true
case Input:
- case Output, Postrouting:
- // Calculate the TCP checksum and set it.
+ case Forward:
+ panic("should not handle packet in the forwarding hook")
+ case Output:
+ dnat = true
+ fallthrough
+ case Postrouting:
if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
@@ -472,62 +444,73 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
updatePseudoHeader = true
}
default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ panic(fmt.Sprintf("unrecognized hook = %d", hook))
}
- rewritePacket(
- netHeader,
- transportHeader,
- updateSRCFields,
- fullChecksum,
- updatePseudoHeader,
- newPort,
- newAddr,
- )
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
- // Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
+ dir := pkt.tuple.direction
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ var tuple *tuple
+ switch dir {
+ case dirOriginal:
+ if dnat {
+ if !cn.destinationManip {
+ return tupleID{}, false
+ }
+ } else if !cn.sourceManip {
+ return tupleID{}, false
+ }
- // Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
- // Update connection state.
- conn.updateLocked(pkt, hook)
+ tuple = &cn.reply
+ case dirReply:
+ if dnat {
+ if !cn.sourceManip {
+ return tupleID{}, false
+ }
+ } else if !cn.destinationManip {
+ return tupleID{}, false
+ }
- return false
-}
+ tuple = &cn.original
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
+ }
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = time.Now()
+ // Update connection state.
+ cn.updateLocked(pkt, dir)
- switch pkt.TransportProtocolNumber {
- case header.TCPProtocolNumber, header.UDPProtocolNumber:
- default:
- // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
- // connections.
+ return tuple.id(), true
+ }()
+ if !performManip {
return
}
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
+ newPort := tid.dstPort
+ newAddr := tid.dstAddr
+ if dnat {
+ newPort = tid.srcPort
+ newAddr = tid.srcAddr
}
- conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(pkt, hook)
- ct.insertConn(conn)
+
+ rewritePacket(
+ pkt.Network(),
+ transportHeader,
+ !dnat,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
+
+ pkt.NatDone = true
}
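
The dir/dnat switch above is dense; this illustrative helper (simplified types, not the real ones) captures which stored tuple supplies the rewrite fields and when manipulation applies at all:

```go
package sketch

type direction int

const (
	dirOriginal direction = iota
	dirReply
)

// rewriteSource mirrors handlePacket's selection logic:
// original-direction packets are rewritten from the reply tuple,
// reply-direction packets are restored from the original tuple, and
// nothing happens unless a NAT rule set the matching manipulation flag.
func rewriteSource(dir direction, dnat, srcManip, dstManip bool) (useReply, ok bool) {
	switch dir {
	case dirOriginal:
		return true, (dnat && dstManip) || (!dnat && srcManip)
	case dirReply:
		return false, (dnat && srcManip) || (!dnat && dstManip)
	default:
		panic("unhandled direction")
	}
}
```
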
// bucket gets the conntrack bucket for a tupleID.
@@ -579,14 +562,15 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
defer ct.mu.RUnlock()
for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
idx = (i + start) % len(ct.buckets)
- ct.buckets[idx].mu.Lock()
- for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ bkt := &ct.buckets[idx]
+ bkt.mu.Lock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
checked++
- if ct.reapTupleLocked(tuple, idx, now) {
+ if ct.reapTupleLocked(tuple, idx, bkt, now) {
expired++
}
}
- ct.buckets[idx].mu.Unlock()
+ bkt.mu.Unlock()
}
// We already checked buckets[idx].
idx++
@@ -611,41 +595,45 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions:
-// * ct.mu is locked for reading.
-// * bucket is locked.
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+// Precondition: ct.mu is read locked and bkt.mu is write locked.
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:ct.mu
+// +checklocks:bkt.mu
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool {
if !tuple.conn.timedOut(now) {
return false
}
// To maintain lock order, we can only reap these tuples if the reply
// appears later in the table.
- replyBucket := ct.bucket(tuple.reply())
- if bucket > replyBucket {
+ replyBktID := ct.bucket(tuple.id().reply())
+ if bktID > replyBktID {
return true
}
// Don't re-lock if both tuples are in the same bucket.
- differentBuckets := bucket != replyBucket
- if differentBuckets {
- ct.buckets[replyBucket].mu.Lock()
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
}
// We have the buckets locked and can remove both tuples.
+ bkt.tuples.Remove(tuple)
+ return true
+}
+
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:b.mu
+func removeConnFromBucket(b *bucket, tuple *tuple) {
if tuple.direction == dirOriginal {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ b.tuples.Remove(&tuple.conn.reply)
} else {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
- }
- ct.buckets[bucket].tuples.Remove(tuple)
-
- // Don't re-unlock if both tuples are in the same bucket.
- if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
+ b.tuples.Remove(&tuple.conn.original)
}
-
- return true
}
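
reapTupleLocked relies on a global ordering over bucket locks (ascending bucket index) to stay deadlock-free, which is why it declines to reap a tuple whose reply lives in a lower-indexed bucket. A standalone sketch of that convention:

```go
package sketch

import "sync"

// lockBucketPair takes two bucket locks in ascending-index order, the
// same ordering rule the conntrack reaper assumes. It returns the
// matching unlock function.
func lockBucketPair(mus []sync.Mutex, a, b int) (unlock func()) {
	if a == b {
		mus[a].Lock()
		return func() { mus[a].Unlock() }
	}
	lo, hi := a, b
	if lo > hi {
		lo, hi = hi, lo
	}
	mus[lo].Lock()
	mus[hi].Lock()
	return func() {
		mus[hi].Unlock()
		mus[lo].Unlock()
	}
}
```
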
func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
@@ -659,14 +647,19 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if !t.conn.destinationManip {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 74c9075b4..5808be685 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -271,7 +271,18 @@ const (
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
- return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+ const hook = Prerouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
}
// CheckInput performs the input hook on the packet.
@@ -281,7 +292,22 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
- return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ const hook = Input
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
}
// CheckForward performs the forward hook on the packet.
@@ -291,6 +317,9 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
}
@@ -301,7 +330,18 @@ func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
- return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+ const hook = Output
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
}
// CheckPostrouting performs the postrouting hook on the packet.
@@ -310,8 +350,38 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string)
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
-func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool {
- return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ const hook = Postrouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ // IPTables only supports IPv4/IPv6.
+ return true
+ }
+
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ return !it.modified
}
// check runs pkt through the rules for hook. It returns true when the packet
@@ -320,20 +390,8 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName str
//
// Precondition: The packet's network and transport header must be set.
func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
- if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
- return true
- }
- // Many users never configure iptables. Spare them the cost of rule
- // traversal if rules have never been set.
it.mu.RLock()
defer it.mu.RUnlock()
- if !it.modified {
- return true
- }
-
- // Packets are manipulated only if connection and matching
- // NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
@@ -361,7 +419,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -377,21 +435,6 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr
}
}
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt, hook)
- }
-
- // Every table returned Accept.
return true
}
@@ -431,7 +474,9 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// Precondition: The packets' network and transport header must be set.
func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
- return it.checkPackets(Output, pkts, r, outNicName)
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ })
}
// CheckPostroutingPackets performs the postrouting hook on the packets.
@@ -439,21 +484,16 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa
// Returns a map of packets that must be dropped.
//
// Precondition: The packets' network and transport header must be set.
-func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
- return it.checkPackets(Postrouting, pkts, r, outNicName)
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ })
}
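
Both batch helpers now share checkPackets, which applies a per-packet closure and builds the drop set lazily. The shape of that refactor in isolation (illustrative types, not the real PacketBufferList):

```go
package sketch

type packet struct{ natDone bool }

// checkAll mirrors checkPackets: it runs check on each packet that
// still needs processing and allocates the drop set only on the first
// failure, keeping the all-accepted fast path allocation-free.
func checkAll(pkts []*packet, check func(*packet) bool) map[*packet]struct{} {
	var drop map[*packet]struct{}
	for _, pkt := range pkts {
		if pkt.natDone {
			continue
		}
		if !check(pkt) {
			if drop == nil {
				drop = make(map[*packet]struct{})
			}
			drop[pkt] = struct{}{}
		}
	}
	return drop
}
```
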
-// checkPackets runs pkts through the rules for hook and returns a map of
-// packets that should not go forward.
-//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-//
-// Precondition: The packets' network and transport header must be set.
-func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok {
+ if ok := f(pkt); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -543,7 +583,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, addressEP)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
index 529e02a07..3d3c39c20 100644
--- a/pkg/tcpip/stack/iptables_state.go
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -26,11 +26,15 @@ type unixTime struct {
// saveLastUsed is invoked by stateify.
func (cn *conn) saveLastUsed() unixTime {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
}
// loadLastUsed is invoked by stateify.
func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
cn.lastUsed = time.Unix(unix.second, unix.nano)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index e8806ebdb..85490e2d4 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,10 +79,49 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
+// DNATTarget modifies the destination port/IP of packets.
+type DNATTarget struct {
+ // The new destination address for packets.
+ //
+ // Immutable.
+ Addr tcpip.Address
+
+ // The new destination port for packets.
+ //
+ // Immutable.
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with.
+ //
+ // Immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Prerouting, Output:
+ case Input, Forward, Postrouting:
+ panic(fmt.Sprintf("%s not supported for DNAT", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */)
+}
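
A hedged sketch of wiring the new target into a stack's NAT table, modeled on the setupNAT helper in the updated integration test further below; the stack and addresses are assumed to be configured already:

```go
import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// installDNAT installs a Prerouting DNAT rule on an IPv4 stack.
func installDNAT(s *stack.Stack, addr tcpip.Address, port uint16) tcpip.Error {
	ipt := s.IPTables()
	table := ipt.GetTable(stack.NATID, false /* ipv6 */)
	ruleIdx := table.BuiltinChains[stack.Prerouting]
	table.Rules[ruleIdx].Target = &stack.DNATTarget{
		NetworkProtocol: ipv4.ProtocolNumber,
		Addr:            addr,
		Port:            port,
	}
	// Make sure the packet is not dropped by the next rule.
	table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
	return ipt.ReplaceTable(stack.NATID, table, false /* ipv6 */)
}
```
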
+
// RedirectTarget redirects the packet to this machine by modifying the
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
@@ -97,7 +136,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -105,16 +144,6 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
rt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
- // Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
- return RuleDrop, 0
- }
-
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
var address tcpip.Address
@@ -132,43 +161,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
panic("redirect target is supported only on output and prerouting hooks")
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
-
- if hook == Output {
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- udpHeader,
- false, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- rt.Port,
- address,
- )
- } else {
- udpHeader.SetDestinationPort(rt.Port)
- }
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
- }
-
- return RuleAccept, 0
+ return natAction(pkt, hook, r, rt.Port, address, true /* dnat */)
}
// SNATTarget modifies the source port/IP in the outgoing packets.
@@ -181,15 +174,7 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
- // Sanity check.
- if st.NetworkProtocol != pkt.NetworkProtocolNumber {
- panic(fmt.Sprintf(
- "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
- st.NetworkProtocol, pkt.NetworkProtocolNumber))
- }
-
+func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -200,6 +185,37 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleDrop, 0
}
+ t := pkt.tuple
+ if t == nil {
+ return RuleDrop, 0
+ }
+
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ default:
+ panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber))
+ }
+ }
+
+ t.conn.performNAT(pkt, hook, r, port, address, dnat)
+ return RuleAccept, 0
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
@@ -208,31 +224,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- port := st.Port
+ return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */)
+}
- if port == 0 {
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- if port == 0 {
- port = header.UDP(pkt.TransportHeader().View()).SourcePort()
- }
- case header.TCPProtocolNumber:
- if port == 0 {
- port = header.TCP(pkt.TransportHeader().View()).SourcePort()
- }
- }
+// MasqueradeTarget modifies the source port/IP in the outgoing packets.
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Set up conection for matching NAT rule. Only the first packet of the
- // connection comes here. Other packets will be manipulated in connection
- // tracking.
- //
- // Does nothing if the protocol does not support connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
}
- return RuleAccept, 0
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
+ return RuleDrop, 0
+ }
+
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */)
}
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 976194124..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -81,17 +81,6 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects v4Tables, v6Tables, and modified.
- mu sync.RWMutex
- // v4Tables and v6tables map tableIDs to tables. They hold builtin
- // tables only, not user tables. mu must be locked for accessing.
- v4Tables [NumTables]Table
- v6Tables [NumTables]Table
- // modified is whether tables have been modified at least once. It is
- // used to elide the iptables performance overhead for workloads that
- // don't utilize iptables.
- modified bool
-
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
@@ -101,6 +90,21 @@ type IPTables struct {
// reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
+
+ mu sync.RWMutex
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables.
+ //
+ // +checklocks:mu
+ v4Tables [NumTables]Table
+ // +checklocks:mu
+ v6Tables [NumTables]Table
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ //
+ // +checklocks:mu
+ modified bool
}
// VisitTargets traverses all the targets of all tables and replaces each with
@@ -352,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index bf248ef20..888a8bd9d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -143,6 +143,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -302,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -329,13 +332,8 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- newPk.NatDone = pk.NatDone
return newPk
}
@@ -367,12 +365,7 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu
newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- newPk.NatDone = pk.NatDone
+ newPk.tuple = pk.tuple
return newPk
}
@@ -425,13 +418,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that it additionally consumes the
+// returned bytes. Subsequent PullUp or Consume will not return these bytes.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
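
The stack_test change further below shows the motivation: the old PullUp-plus-DeleteFront pair forced a defensive copy because DeleteFront invalidated slices, whereas Consume returns the bytes and advances past them in one step. A small illustrative helper using only the exported API:

```go
import (
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// consumeHeader returns the first hdrLen bytes of pkt's data and leaves
// pkt.Data() positioned just past them. Unlike PullUp+DeleteFront, the
// returned view stays valid and no defensive copy is needed.
func consumeHeader(pkt *stack.PacketBuffer, hdrLen int) (buffer.View, bool) {
	return pkt.Data().Consume(hdrLen)
}
```
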
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 87b023445..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
-func TestPacketBufferClone(t *testing.T) {
- data := concatViews(makeView(20), makeView(30), makeView(40))
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- bytesToDelete := 30
- originalSize := data.Size()
-
- clonedPks := []*PacketBuffer{
- pk.Clone(),
- pk.CloneToInbound(),
- }
- pk.Data().DeleteFront(bytesToDelete)
- if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
- t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
- }
- for _, clonedPk := range clonedPks {
- if got := clonedPk.Data().Size(); got != originalSize {
- t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
- }
- }
-}
-
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
@@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index cd4137794..c23e91702 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index dc7289441..a941091b0 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -289,6 +289,12 @@ type TCPSenderState struct {
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
+
+ // RetransmitTS records the timestamp used to detect spurious recovery.
+ RetransmitTS uint32
+
+ // SpuriousRecovery indicates if the sender entered recovery spuriously.
+ SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index d45a2c05c..460a6afaf 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -423,9 +423,9 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -471,10 +471,10 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns KUID of the packet.
+ // KUID returns KUID of the packet.
KUID() uint32
- // GID returns KGID of the packet.
+ // KGID returns KGID of the packet.
KGID() uint32
}
@@ -1245,11 +1245,11 @@ type Route struct {
// String implements the fmt.Stringer interface.
func (r Route) String() string {
var out strings.Builder
- fmt.Fprintf(&out, "%s", r.Destination)
+ _, _ = fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
- fmt.Fprintf(&out, " via %s", r.Gateway)
+ _, _ = fmt.Fprintf(&out, " via %s", r.Gateway)
}
- fmt.Fprintf(&out, " nic %d", r.NIC)
+ _, _ = fmt.Fprintf(&out, " nic %d", r.NIC)
return out.String()
}
@@ -1286,7 +1286,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value(name ...string) uint64 {
+func (s *StatCounter) Value(...string) uint64 {
return s.count.Load()
}
@@ -1865,6 +1865,10 @@ type TCPStats struct {
// SegmentsAckedWithDSACK is the number of segments acknowledged with
// DSACK.
SegmentsAckedWithDSACK *StatCounter
+
+ // SpuriousRecovery is the number of times the connection entered loss
+ // recovery spuriously.
+ SpuriousRecovery *StatCounter
}
// UDPStats collects UDP-specific stats.
diff --git a/pkg/tcpip/tcpip_state.go b/pkg/tcpip/tcpip_state.go
new file mode 100644
index 000000000..1953e24a1
--- /dev/null
+++ b/pkg/tcpip/tcpip_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "time"
+)
+
+func (c *ControlMessages) saveTimestamp() int64 {
+ return c.Timestamp.UnixNano()
+}
+
+func (c *ControlMessages) loadTimestamp(nsec int64) {
+ c.Timestamp = time.Unix(0, nsec)
+}
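
These hooks pair with the `state:".(int64)"` tag added to ControlMessages.Timestamp above: stateify serializes the field through saveTimestamp and restores it through loadTimestamp, so the saved format stays int64 nanoseconds while the in-memory type becomes time.Time. Conceptually (these methods are unexported, so this sketch only compiles inside package tcpip and is normally driven by generated code):

```go
func exampleRoundTrip() time.Time {
	var c ControlMessages
	c.Timestamp = time.Unix(0, 1234)

	nsec := c.saveTimestamp() // what the saved state records: int64 nanoseconds

	var restored ControlMessages
	restored.loadTimestamp(nsec) // what restore reconstructs
	return restored.Timestamp    // == time.Unix(0, 1234)
}
```
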
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index bdf4a64b9..7f872c271 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1162,19 +1162,19 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
}
}
-func TestSNAT(t *testing.T) {
- const listenPort = 8080
+func TestNAT(t *testing.T) {
+ const listenPort uint16 = 8080
type endpointAndAddresses struct {
- serverEP tcpip.Endpoint
- serverAddr tcpip.Address
- serverReadableCH chan struct{}
-
- clientEP tcpip.Endpoint
- clientAddr tcpip.Address
- clientReadableCH chan struct{}
-
- nattedClientAddr tcpip.Address
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.FullAddress
+ serverReadableCH chan struct{}
+ serverConnectAddr tcpip.Address
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ clientConnectAddr tcpip.FullAddress
}
newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
@@ -1195,71 +1195,247 @@ func TestSNAT(t *testing.T) {
return ep, ch
}
+ setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+ table := ipt.GetTable(stack.NATID, ipv6)
+ ruleIdx := table.BuiltinChains[hook]
+ table.Rules[ruleIdx].Filter = filter
+ table.Rules[ruleIdx].Target = target
+ // Make sure the packet is not dropped by the next rule.
+ table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+
+ setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ setupNAT(
+ t,
+ s,
+ netProto,
+ stack.Prerouting,
+ stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ target)
+ }
+
+ setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ setupNAT(
+ t,
+ s,
+ netProto,
+ stack.Postrouting,
+ stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ target)
+ }
+
+ type natType struct {
+ name string
+ setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address)
+ }
+
+ snatTypes := []natType{
+ {
+ name: "SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) {
+ t.Helper()
+
+ setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
+ },
+ },
+ {
+ name: "Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
+ t.Helper()
+
+ setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ }
+ dnatTypes := []natType{
+ {
+ name: "Redirect",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
+ t.Helper()
+
+ setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort})
+ },
+ },
+ {
+ name: "DNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort})
+ },
+ },
+ }
+
tests := []struct {
- name string
+ name string
+ netProto tcpip.NetworkProtocolNumber
+	// Sets up the stacks in such a way that:
+ //
+ // - Host2 is the client for all tests.
+	// - Host1 is the server when performing SNAT.
+ // + NAT will transform client-originating packets' source addresses to
+ // the router's NIC1's address before reaching Host1.
+ // - Router is the server when performing DNAT (client will still attempt to
+ // send packets to Host1).
+ // + NAT will transform client-originating packets' destination addresses
+ // to the router's NIC2's address.
epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ natTypes []natType
}{
{
- name: "IPv4 host1 server with host2 client",
+ name: "IPv4 SNAT",
+ netProto: ipv4.ProtocolNumber,
epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
t.Helper()
- ipt := routerStack.IPTables()
- filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
- filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err)
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
}
-
- ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
clientEP: ep2,
clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
clientReadableCH: ep2WECH,
-
- nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
}
},
+ natTypes: snatTypes,
},
{
- name: "IPv6 host1 server with host2 client",
+ name: "IPv4 DNAT",
+ netProto: ipv4.ProtocolNumber,
epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
t.Helper()
- ipt := routerStack.IPTables()
- filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
- filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err)
+ // If we are performing DNAT, then the packet will be redirected
+ // to the router.
+ listenerStack := routerStack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
}
+ serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address
+ // DNAT will update the destination port to what the server is
+ // bound to.
+ clientConnectPort := serverAddr.Port + 1
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
- ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: dnatTypes,
+ },
+ {
+ name: "IPv6 SNAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
clientEP: ep2,
clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: snatTypes,
+ },
+ {
+ name: "IPv6 DNAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ // If we are performing DNAT, then the packet will be redirected
+ // to the router.
+ listenerStack := routerStack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address
+ // DNAT will update the destination port to what the server is
+ // bound to.
+ clientConnectPort := serverAddr.Port + 1
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
- nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
}
},
+ natTypes: dnatTypes,
},
}
@@ -1328,116 +1504,121 @@ func TestSNAT(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
for _, subTest := range subTests {
t.Run(subTest.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- }
+ for _, natType := range test.natTypes {
+ t.Run(natType.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
- host1Stack := stack.New(stackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
- epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
- serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
- if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
- }
- clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
- if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
- }
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr)
- if subTest.setupServer != nil {
- subTest.setupServer(t, epsAndAddrs.serverEP)
- }
- {
- err := epsAndAddrs.clientEP.Connect(serverAddr)
- if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
- }
- }
- nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
- if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
- } else {
- nattedClientAddr.Port = addr.Port
- }
+ if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
- serverEP := epsAndAddrs.serverEP
- serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
- defer ep.Close()
- serverEP = ep
- serverCH = ch
- }
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff)
+ }
+ }
+ serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ serverConnectAddr.Port = addr.Port
+ }
- write := func(ep tcpip.Endpoint, data []byte) {
- t.Helper()
-
- var r bytes.Reader
- r.Reset(data)
- var wOpts tcpip.WriteOptions
- n, err := ep.Write(&r, wOpts)
- if err != nil {
- t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
- }
- }
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
- read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
- t.Helper()
-
- var buf bytes.Buffer
- var res tcpip.ReadResult
- for {
- var err tcpip.Error
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err = ep.Read(&buf, opts)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
}
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
}
- break
- }
-
- readResult := tcpip.ReadResult{
- Count: len(data),
- Total: len(data),
- }
- if subTest.needRemoteAddr {
- readResult.RemoteAddr = expectedFrom
- }
- if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
- {
- data := []byte{1, 2, 3, 4}
- write(epsAndAddrs.clientEP, data)
- read(serverCH, serverEP, data, nattedClientAddr)
- }
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, serverConnectAddr)
+ }
- {
- data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
- write(serverEP, data)
- read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr)
+ }
+ })
}
})
}
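
The NAT test rework above collapses SNAT and DNAT into one matrix: each top-level case now carries natTypes, and the body runs case x transport x natType, with a per-iteration setupNAT hook and fresh stacks each time. A minimal, self-contained sketch of that nesting (hypothetical natType values standing in for snatTypes/dnatTypes and the utils helpers):

package nat_test

import "testing"

// natType pairs a name with a hook that installs the NAT rule under test.
type natType struct {
	name     string
	setupNAT func(t *testing.T)
}

func TestNATMatrix(t *testing.T) {
	tests := []struct {
		name     string
		natTypes []natType
	}{
		{name: "IPv4 SNAT", natTypes: []natType{{name: "SNAT", setupNAT: func(t *testing.T) { /* install rule */ }}}},
		{name: "IPv6 DNAT", natTypes: []natType{{name: "DNAT", setupNAT: func(t *testing.T) { /* install rule */ }}}},
	}
	subTests := []struct{ name string }{{name: "UDP"}, {name: "TCP"}}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			for _, subTest := range subTests {
				t.Run(subTest.name, func(t *testing.T) {
					for _, nat := range test.natTypes {
						t.Run(nat.name, func(t *testing.T) {
							// Fresh state per natType keeps NAT rules isolated,
							// mirroring the per-iteration stack.New calls above.
							nat.setupNAT(t)
						})
					}
				})
			}
		})
	}
}
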
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index bb0db9f70..31579a896 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -180,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
},
}
if opts.NeedRemoteAddr {
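
This hunk, like the matching ones in the packet, raw, and udp endpoints further down, changes ControlMessages.Timestamp from a pre-converted int64 of nanoseconds to the time.Time itself, leaving unit conversion to whoever formats the timestamp. A small sketch of the new shape, with a hypothetical controlMessages type:

package main

import (
	"fmt"
	"time"
)

// controlMessages mirrors the post-change shape: the raw time.Time is
// stored, and callers convert to whatever resolution they need.
type controlMessages struct {
	HasTimestamp bool
	Timestamp    time.Time // was: int64 (nanoseconds, via UnixNano())
}

func main() {
	cm := controlMessages{HasTimestamp: true, Timestamp: time.Now()}
	// A SO_TIMESTAMP-style consumer converts at the boundary instead of
	// every endpoint calling UnixNano() up front.
	fmt.Println(cm.Timestamp.Unix(), cm.Timestamp.UnixNano())
}
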
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 689427d53..80eef39e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -182,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.receivedAt.UnixNano(),
+ Timestamp: packet.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -409,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index bfef75da7..ce76774af 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -49,6 +49,7 @@ type rawPacket struct {
receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
}
// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
@@ -202,12 +203,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.receivedAt.UnixNano(),
+ Timestamp: pkt.receivedAt,
},
}
if opts.NeedRemoteAddr {
res.RemoteAddr = pkt.senderAddr
}
+ switch netProto := e.net.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceivePacketInfo() {
+ res.ControlMessages.HasIPPacketInfo = true
+ res.ControlMessages.PacketInfo = pkt.packetInfo
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ res.ControlMessages.HasIPv6PacketInfo = true
+ res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: pkt.packetInfo.NIC,
+ Addr: pkt.packetInfo.DestinationAddr,
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", netProto))
+ }
n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
@@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return false
}
- srcAddr := pkt.Network().SourceAddress()
+ net := pkt.Network()
+ dstAddr := net.DestinationAddress()
+ srcAddr := net.SourceAddress()
info := e.net.Info()
switch state := e.net.State(); state {
@@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
// If bound to an address, only accept data for that address.
- if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ if info.BindAddr != "" && info.BindAddr != dstAddr {
return false
}
default:
@@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
NIC: pkt.NICID,
Addr: srcAddr,
},
+ packetInfo: tcpip.IPPacketInfo{
+ // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast
+ // address. LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ LocalAddr: dstAddr,
+ DestinationAddr: dstAddr,
+ NIC: pkt.NICID,
+ },
}
// Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
@@ -483,10 +511,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// overlapping slices.
var combinedVV buffer.VectorisedView
if info.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
+ networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader))
+ headers = append(headers, networkHeader...)
+ headers = append(headers, transportHeader...)
combinedVV = headers.ToVectorisedView()
} else {
combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
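
The raw endpoint now snapshots per-packet metadata (NIC and destination address) in HandlePacket, while the headers are still in hand, so Read can surface IP_PKTINFO/IPV6_RECVPKTINFO control messages later. A minimal sketch of that capture-then-surface pattern, with hypothetical stand-in types:

package main

import "fmt"

// packetInfo is a stand-in for tcpip.IPPacketInfo: ingress metadata captured
// at receive time.
type packetInfo struct {
	NIC      int
	DestAddr string
}

type rawPacket struct {
	payload []byte
	info    packetInfo // recorded in handlePacket, consumed in read
}

type endpoint struct {
	queue       []rawPacket
	wantPktInfo bool // e.g. toggled by IP_PKTINFO
}

func (e *endpoint) handlePacket(nic int, dst string, payload []byte) {
	// Capture the metadata now; the network header is gone by read time.
	e.queue = append(e.queue, rawPacket{payload: payload, info: packetInfo{NIC: nic, DestAddr: dst}})
}

func (e *endpoint) read() ([]byte, *packetInfo) {
	pkt := e.queue[0]
	e.queue = e.queue[1:]
	if e.wantPktInfo {
		return pkt.payload, &pkt.info
	}
	return pkt.payload, nil
}

func main() {
	e := &endpoint{wantPktInfo: true}
	e.handlePacket(1, "192.0.2.1", []byte("hi"))
	data, info := e.read()
	fmt.Println(string(data), info.NIC, info.DestAddr)
}
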
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 7115d0a12..caf14b0dc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -15,12 +15,12 @@
package tcp
import (
+ "container/list"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -100,18 +100,6 @@ type listenContext struct {
	// netProto indicates the network protocol (IPv4/v6) for the listening
// endpoint.
netProto tcpip.NetworkProtocolNumber
-
- // pendingMu protects pendingEndpoints. This should only be accessed
- // by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
- pendingMu sync.Mutex
- // pending is used to wait for all pendingEndpoints to finish when
- // a socket is closed.
- pending sync.WaitGroup
- // pendingEndpoints is a map of all endpoints for which a handshake is
- // in progress.
- pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 {
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stk,
- protocol: protocol,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6Only: v6Only,
- netProto: netProto,
- listenEP: listenEP,
- pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+ stack: stk,
+ protocol: protocol,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
}
for i := range l.nonce {
@@ -265,7 +252,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
@@ -275,8 +261,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -295,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
@@ -336,38 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for _, n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -378,9 +332,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -444,101 +395,30 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// handleSynSegment is called in its own goroutine once the listening endpoint
-// receives a SYN segment. It is responsible for completing the handshake and
-// queueing the new endpoint for acceptance.
-//
-// A limited number of these goroutines are allowed before TCP starts using SYN
-// cookies to accept connections.
-//
-// +checklocks:e.mu
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error {
- defer s.decRef()
-
- h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
- if err != nil {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
- }
-
- go func() {
- // Note that startHandshake returns a locked endpoint. The
- // force call here just makes it so.
- if err := h.complete(); err != nil { // +checklocksforce
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
- return
- }
- ctx.cleanupCompletedHandshake(h)
- h.ep.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
-
- // Deliver the endpoint to the accept queue.
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- // If the listener has transitioned out of the listen state (accepted
- // is the zero value), the new endpoint is reset instead.
- return false
- }
- if e.accepted.acceptQueueIsFullLocked() {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(h.ep)
- atomic.AddInt32(&e.synRcvdCount, -1)
- return true
- }
- }()
-
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- h.ep.notifyProtocolGoroutine(notifyReset)
- }
- }()
-
- return nil
-}
-
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := e.accepted.acceptQueueIsFullLocked()
+ full := e.acceptQueue.isFull()
e.acceptMu.Unlock()
return full
}
-func (a *accepted) acceptQueueIsFullLocked() bool {
- return a.endpoints.Len() == a.cap
+// +stateify savable
+type acceptQueue struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
+
+ // pendingEndpoints is a set of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[*endpoint]struct{}
+
+ // capacity is the maximum number of endpoints that can be in endpoints.
+ capacity int
+}
+
+func (a *acceptQueue) isFull() bool {
+ return a.endpoints.Len() == a.capacity
}
// handleListenSegment is called when a listening endpoint receives a segment
@@ -571,20 +451,96 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
- alwaysUseSynCookies := func() bool {
+ opts := parseSynSegmentOptions(s)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
}
- return bool(alwaysUseSynCookies)
- }()
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
- opts := parseSynSegmentOptions(s)
- if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
- s.incRef()
- atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, opts)
+			// The capacity of the accepted queue is always one greater than the
+			// listen backlog, but the SYN-RCVD connections count is checked against
+			// the listen backlog value for Linux parity reasons.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 {
+ return true, nil
+ }
+
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return false, err
+ }
+
+ e.acceptQueue.pendingEndpoints[h.ep] = struct{}{}
+ e.pendingAccepted.Add(1)
+
+ go func() {
+ defer func() {
+ e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ delete(e.acceptQueue.pendingEndpoints, h.ep)
+ }()
+
+ // Note that startHandshake returns a locked endpoint. The force call
+ // here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ // Deliver the endpoint to the accept queue.
+ //
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ // The listener is transitioning out of the Listen state; bail.
+ if e.acceptQueue.capacity == 0 {
+ return false
+ }
+ if e.acceptQueue.isFull() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.acceptQueue.endpoints.PushBack(h.ep)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
+ }()
+
+ return false, nil
+ }()
+ if err != nil {
+ return err
}
+ if !useSynCookies {
+ return nil
+ }
+
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -627,23 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- // Keep hold of acceptMu until the new endpoint is in the accept queue (or
- // if there is an error), to guarantee that we will keep our spot in the
- // queue even if another handshake from the syn queue completes.
- e.acceptMu.Lock()
- if e.accepted.acceptQueueIsFullLocked() {
- // Silently drop the ack as the application can't accept
- // the connection at this point. The ack will be
- // retransmitted by the sender anyway and we can
- // complete the connection at the time of retransmit if
- // the backlog has space.
- e.acceptMu.Unlock()
- e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -659,7 +598,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
- e.acceptMu.Unlock()
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -680,6 +618,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// ACK was received from the sender.
return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
+
+ // Keep hold of acceptMu until the new endpoint is in the accept queue (or
+ // if there is an error), to guarantee that we will keep our spot in the
+ // queue even if another handshake from the syn queue completes.
+ e.acceptMu.Lock()
+ if e.acceptQueue.isFull() {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.acceptMu.Unlock()
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
rcvdSynOptions := header.TCPSynOptions{
@@ -769,7 +725,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
// Deliver the endpoint to the accept queue.
- e.accepted.endpoints.PushBack(n)
+ e.acceptQueue.endpoints.PushBack(n)
e.acceptMu.Unlock()
e.waiterQueue.Notify(waiter.ReadableEvents)
@@ -789,14 +745,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
- // Mark endpoint as closed. This will prevent goroutines running
- // handleSynSegment() from attempting to queue new connections
- // to the endpoint.
e.setEndpointState(StateClose)
- // Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
-
// Do cleanup if needed.
e.completeWorkerLocked()
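
The accept.go refactor above folds the old listenContext bookkeeping (pendingMu, the pending WaitGroup, and the per-ID pendingEndpoints map) into a single acceptQueue guarded by the endpoint's acceptMu, and chooses between a tracked handshake and a SYN cookie based on how many handshakes are already pending. A condensed sketch of that decision, with hypothetical types standing in for the real endpoint:

package main

import (
	"container/list"
	"fmt"
	"sync"
)

type conn struct{ id int }

// acceptQueue mirrors the new struct: completed connections plus the set of
// in-flight handshakes, both bounded by one capacity value.
type acceptQueue struct {
	endpoints list.List
	pending   map[*conn]struct{}
	capacity  int
}

func (a *acceptQueue) isFull() bool { return a.endpoints.Len() == a.capacity }

type listener struct {
	mu    sync.Mutex
	queue acceptQueue
}

// onSYN reports whether the listener should fall back to SYN cookies
// instead of tracking a full handshake (pending is compared against the
// backlog, i.e. capacity-1, for Linux parity).
func (l *listener) onSYN(c *conn) (useSynCookies bool) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if len(l.queue.pending) == l.queue.capacity-1 {
		return true
	}
	l.queue.pending[c] = struct{}{} // removed again when the handshake completes
	return false
}

func main() {
	l := &listener{queue: acceptQueue{pending: map[*conn]struct{}{}, capacity: 2}}
	fmt.Println(l.onSYN(&conn{id: 1})) // false: tracked handshake
	fmt.Println(l.onSYN(&conn{id: 2})) // true: pending-handshake backlog hit
}
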
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 407ab2664..066ffe051 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,7 +15,6 @@
package tcp
import (
- "container/list"
"encoding/binary"
"fmt"
"io"
@@ -205,6 +204,8 @@ type SACKInfo struct {
}
// ReceiveErrors collect segment receive errors within transport layer.
+//
+// +stateify savable
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -234,6 +235,8 @@ type ReceiveErrors struct {
}
// SendErrors collect segment send errors within the transport layer.
+//
+// +stateify savable
type SendErrors struct {
tcpip.SendErrors
@@ -257,6 +260,8 @@ type SendErrors struct {
}
// Stats holds statistics about the endpoint.
+//
+// +stateify savable
type Stats struct {
// SegmentsReceived is the number of TCP segments received that
// the transport layer successfully parsed.
@@ -311,18 +316,6 @@ type rcvQueueInfo struct {
rcvQueue segmentList `state:"wait"`
}
-// +stateify savable
-type accepted struct {
- // NB: this could be an endpointList, but ilist only permits endpoints to
- // belong to one list at a time, and endpoints are already stored in the
- // dispatcher's list.
- endpoints list.List `state:".([]*endpoint)"`
-
- // cap is the maximum number of endpoints that can be in the accepted endpoint
- // list.
- cap int
-}
-
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint; they are properly
@@ -338,7 +331,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> Protects e.accepted.
+// e.acceptMu -> Protects e.acceptQueue.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -502,10 +495,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
@@ -579,7 +568,7 @@ type endpoint struct {
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
// +checklocks:acceptMu
- accepted accepted
+ acceptQueue acceptQueue
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -612,8 +601,7 @@ type endpoint struct {
gso stack.GSO
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats Stats `state:"nosave"`
+ stats Stats
	// tcpLingerTimeout is the maximum amount of time a socket
	// stays in TIME_WAIT state before being marked
@@ -825,10 +813,9 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
waiterQueue: waiterQueue,
state: uint32(StateInitial),
keepalive: keepalive{
- // Linux defaults.
- idle: 2 * time.Hour,
- interval: 75 * time.Second,
- count: 9,
+ idle: DefaultKeepaliveIdle,
+ interval: DefaultKeepaliveInterval,
+ count: DefaultKeepaliveCount,
},
uniqueID: s.UniqueID(),
txHash: s.Rand().Uint32(),
@@ -910,7 +897,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if e.accepted.endpoints.Len() != 0 {
+ if e.acceptQueue.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -1093,20 +1080,20 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- acceptedCopy := e.accepted
- e.accepted = accepted{}
- e.acceptMu.Unlock()
-
- if acceptedCopy == (accepted{}) {
- return
+ // Close any endpoints in SYN-RCVD state.
+ for n := range e.acceptQueue.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
}
-
- e.acceptCond.Broadcast()
-
+ e.acceptQueue.pendingEndpoints = nil
// Reset all connections that are waiting to be accepted.
- for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() {
n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
+ e.acceptQueue.endpoints.Init()
+ e.acceptMu.Unlock()
+
+ e.acceptCond.Broadcast()
+
// Wait for reset of all endpoints that are still waiting to be delivered to
	// the now-closed accept queue.
e.pendingAccepted.Wait()
@@ -2498,22 +2485,23 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.accepted == (accepted{}) {
- // listen is called after shutdown.
- e.accepted.cap = backlog
- e.shutdownFlags = 0
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcvQueueInfo.RcvClosed = false
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- } else {
- // Adjust the size of the backlog iff we can fit
- // existing pending connections into the new one.
- if e.accepted.endpoints.Len() > backlog {
- return &tcpip.ErrInvalidEndpointState{}
- }
- e.accepted.cap = backlog
+
+ // Adjust the size of the backlog iff we can fit
+ // existing pending connections into the new one.
+ if e.acceptQueue.endpoints.Len() > backlog {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+ e.acceptQueue.capacity = backlog
+
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
}
+ e.shutdownFlags = 0
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+
// Notify any blocked goroutines that they can attempt to
// deliver endpoints again.
e.acceptCond.Broadcast()
@@ -2548,8 +2536,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.accepted == (accepted{}) {
- e.accepted.cap = backlog
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+ if e.acceptQueue.capacity == 0 {
+ e.acceptQueue.capacity = backlog
}
e.acceptMu.Unlock()
@@ -2589,8 +2580,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
// Get the new accepted endpoint.
var n *endpoint
e.acceptMu.Lock()
- if element := e.accepted.endpoints.Front(); element != nil {
- n = e.accepted.endpoints.Remove(element).(*endpoint)
+ if element := e.acceptQueue.endpoints.Front(); element != nil {
+ n = e.acceptQueue.endpoints.Remove(element).(*endpoint)
}
e.acceptMu.Unlock()
if n == nil {
@@ -3007,6 +2998,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
+ s.Sender.RetransmitTS = e.snd.retransmitTS
+ s.Sender.SpuriousRecovery = e.snd.spuriousRecovery
return s
}
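
Among the endpoint.go changes above, listen() now applies one rule unconditionally: a backlog change is accepted only if the connections already queued still fit, whether the endpoint is freshly listening or re-listening after shutdown. A small self-contained sketch of that rule (hypothetical listener type, not the real endpoint):

package main

import (
	"errors"
	"fmt"
)

var errInvalidEndpointState = errors.New("invalid endpoint state")

type listener struct {
	queued   int // connections accepted by the stack but not yet Accept()ed
	capacity int
}

// setBacklog mirrors the listen() rule above: shrinking is refused if it
// would strand connections that are already queued.
func (l *listener) setBacklog(backlog int) error {
	if l.queued > backlog {
		return errInvalidEndpointState
	}
	l.capacity = backlog
	return nil
}

func main() {
	l := &listener{queued: 3}
	fmt.Println(l.setBacklog(2)) // error: 3 queued connections would not fit
	fmt.Println(l.setBacklog(8)) // <nil>: growing (or equal) is always fine
}
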
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 381f4474d..94072a115 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() {
}
// saveEndpoints is invoked by stateify.
-func (a *accepted) saveEndpoints() []*endpoint {
+func (a *acceptQueue) saveEndpoints() []*endpoint {
acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
acceptedEndpoints[i] = e.Value.(*endpoint)
@@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint {
}
// loadEndpoints is invoked by stateify.
-func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
for _, ep := range acceptedEndpoints {
a.endpoints.PushBack(ep)
}
@@ -252,7 +252,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
bind()
e.acceptMu.Lock()
- backlog := e.accepted.cap
+ backlog := e.acceptQueue.capacity
e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index e4410ad93..f122ea009 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -66,6 +66,18 @@ const (
// DefaultSynRetries is the default value for the number of SYN retransmits
// before a connect is aborted.
DefaultSynRetries = 6
+
+ // DefaultKeepaliveIdle is the idle time for a connection before keep-alive
+ // probes are sent.
+ DefaultKeepaliveIdle = 2 * time.Hour
+
+ // DefaultKeepaliveInterval is the time between two successive keep-alive
+ // probes.
+ DefaultKeepaliveInterval = 75 * time.Second
+
+ // DefaultKeepaliveCount is the number of keep-alive probes that are sent
+ // before declaring the connection dead.
+ DefaultKeepaliveCount = 9
)
const (
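
The constants above hoist Linux's keepalive defaults out of newEndpoint so other packages can reference them. As a quick worked check of what they imply: roughly idle + count x interval elapses, worst case, before a silent peer is declared dead.

package main

import (
	"fmt"
	"time"
)

// Mirrors the exported defaults above (Linux parity values).
const (
	DefaultKeepaliveIdle     = 2 * time.Hour
	DefaultKeepaliveInterval = 75 * time.Second
	DefaultKeepaliveCount    = 9
)

func main() {
	// Worst case before a silent peer is declared dead:
	// 2h of idle, then 9 unanswered probes spaced 75s apart.
	total := DefaultKeepaliveIdle + DefaultKeepaliveCount*DefaultKeepaliveInterval
	fmt.Println(total) // 2h11m15s
}
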
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 2fabf1594..4377f07a0 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -144,6 +144,15 @@ type sender struct {
	// probeTimer and probeWaker are used to schedule the PTO for the RACK TLP algorithm.
probeTimer timer `state:"nosave"`
probeWaker sleep.Waker `state:"nosave"`
+
+ // spuriousRecovery indicates whether the sender entered recovery
+ // spuriously as described in RFC3522 Section 3.2.
+ spuriousRecovery bool
+
+	// retransmitTS is the timestamp at which the sender sends the retransmitted
+	// segment after entering an RTO for the first time, as described in
+ // RFC3522 Section 3.2.
+ retransmitTS uint32
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -425,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // Initialize the variables used to detect spurious recovery after
+ // entering RTO.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
// TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
// when writeList is empty. Remove this once we have a proper fix for this
// issue.
@@ -495,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
@@ -958,6 +978,13 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
+ // Initialize the variables used to detect spurious recovery after
+ // entering recovery.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
@@ -972,6 +999,11 @@ func (s *sender) enterRecovery() {
s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
s.FastRecovery.HighRxt = s.SndUna
s.FastRecovery.RescueRxt = s.SndUna
+
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
@@ -1147,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool {
// Iterate the writeList and update RACK for each segment which is newly acked
// either cumulatively or selectively. Loop through the segments which are
// sacked, and update the RACK related variables and check for reordering.
+// Returns true when a DSACK block has been detected in the received ACK.
//
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
-func (s *sender) walkSACK(rcvdSeg *segment) {
+func (s *sender) walkSACK(rcvdSeg *segment) bool {
s.rc.setDSACKSeen(false)
// Look for DSACK block.
+ hasDSACK := false
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
@@ -1167,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.setDSACKSeen(true)
idx = 1
n--
+ hasDSACK = true
}
if n == 0 {
- return
+ return hasDSACK
}
// Sort the SACK blocks. The first block is the most recent unacked
@@ -1193,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
seg = seg.Next()
}
}
+ return hasDSACK
}
// checkDSACK checks if a DSACK is reported.
@@ -1239,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool {
return false
}
+func (s *sender) recordRetransmitTS() {
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2
+ //
+ // The Eifel detection algorithm is used, only upon initiation of loss
+ // recovery, i.e., when either the timeout-based retransmit or the fast
+ // retransmit is sent. The Eifel detection algorithm MUST NOT be
+ // reinitiated after loss recovery has already started. In particular,
+ // it must not be reinitiated upon subsequent timeouts for the same
+ // segment, and not upon retransmitting segments other than the oldest
+ // outstanding segment, e.g., during selective loss recovery.
+ if s.inRecovery() {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ //
+ // Set a "RetransmitTS" variable to the value of the Timestamp Value
+ // field of the Timestamps option included in the retransmit sent when
+ // loss recovery is initiated. A TCP sender must ensure that
+ // RetransmitTS does not get overwritten as loss recovery progresses,
+ // e.g., in case of a second timeout and subsequent second retransmit of
+ // the same octet.
+ s.retransmitTS = s.ep.tsValNow()
+}
+
+func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) {
+ // Return if the sender has already detected spurious recovery.
+ if s.spuriousRecovery {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4
+ //
+ // If the value of the Timestamp Echo Reply field of the acceptable ACK's
+ // Timestamps option is smaller than the value of RetransmitTS, then
+ // proceed to next step, else return.
+ if tsEchoReply >= s.retransmitTS {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If the acceptable ACK carries a DSACK option [RFC2883], then return.
+ if hasDSACK {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If during the lifetime of the TCP connection the TCP sender has
+ // previously received an ACK with a DSACK option, or the acceptable ACK
+ // does not acknowledge all outstanding data, then proceed to next step,
+ // else return.
+ numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value()
+ if numDSACK == 0 && s.SndUna == s.SndNxt {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6
+ //
+ // If the loss recovery has been initiated with a timeout-based
+ // retransmit, then set
+ // SpuriousRecovery <- SPUR_TO (equal 1),
+ // else set
+ // SpuriousRecovery <- dupacks+1
+ // Set the spurious recovery variable to true as we do not differentiate
+ // between fast, SACK or RTO recovery.
+ s.spuriousRecovery = true
+ s.ep.stack.Stats().TCP.SpuriousRecovery.Increment()
+}
+
+// inRecovery returns true if the sender is in the RTORecovery, FastRecovery or SACKRecovery state.
+func (s *sender) inRecovery() bool {
+ if s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery {
+ return true
+ }
+ return false
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1254,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Insert SACKBlock information into our scoreboard.
+ hasDSACK := false
if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
@@ -1288,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// RACK.fack, then the corresponding packet has been
// reordered and RACK.reord is set to TRUE.
if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- s.walkSACK(rcvdSeg)
+ hasDSACK = s.walkSACK(rcvdSeg)
}
s.SetPipe()
}
@@ -1418,6 +1534,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
+ // Detect if the sender entered recovery spuriously.
+ if s.inRecovery() {
+ s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr)
+ }
+
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
if !s.FastRecovery.Active {
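
To see the RFC 3522 (Eifel) logic above end to end: RetransmitTS is latched once when loss recovery starts, and a later acceptable ACK proves the recovery spurious only if it echoes a timestamp older than that retransmit, carries no DSACK, and either DSACKs have been seen before on the connection or data is still outstanding. A condensed, self-contained sketch of those steps, with hypothetical sender fields but the same step numbering:

package main

import "fmt"

type sender struct {
	inRecovery       bool
	retransmitTS     uint32 // Step 2: TSVal of the first retransmit
	spuriousRecovery bool
	numDSACKSeen     uint64
	allDataAcked     bool
}

// enterRecovery latches retransmitTS exactly once per recovery episode
// (Step 1 resets, Step 2 records; never re-latched mid-recovery).
func (s *sender) enterRecovery(tsValNow uint32) {
	if s.inRecovery {
		return
	}
	s.inRecovery = true
	s.spuriousRecovery = false
	s.retransmitTS = tsValNow
}

// onACK applies Steps 4-6 to an acceptable ACK.
func (s *sender) onACK(tsEcr uint32, hasDSACK bool) {
	if s.spuriousRecovery {
		return
	}
	if tsEcr >= s.retransmitTS { // Step 4: echo must predate the retransmit
		return
	}
	if hasDSACK { // Step 5: a DSACK means the retransmit really was needed
		return
	}
	if s.numDSACKSeen == 0 && s.allDataAcked { // Step 5 (cont.)
		return
	}
	s.spuriousRecovery = true // Step 6
}

func main() {
	s := &sender{}
	s.enterRecovery(100) // retransmit sent with TSVal=100
	s.onACK(99, false)   // ACK echoes the *original* transmission
	fmt.Println(s.spuriousRecovery) // true: the RTO fired needlessly
}
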
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index c35db7c95..0d36d0dd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -1059,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) {
for i := 0; i < numPkts; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
}
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ // Expect retransmission of last packet due to TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize)
+
+ // SACK for first and last packet.
+ start := c.IRS.Add(seqnum.Size(maxPayload))
end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
+ dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
var info tcpip.TCPInfoOption
if err := c.EP.GetSockOpt(&info); err != nil {
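
The rewritten test now acknowledges the TLP-retransmitted tail with a DSACK: per RFC 2883 the duplicate range goes in the first SACK block, and it is recognized as a DSACK either because it sits below the cumulative ACK or because the second block contains it. A small illustrative sketch of that detection rule (hypothetical types, not the stack's checkDSACK):

package main

import "fmt"

type sackBlock struct{ start, end uint32 }

// isDSACK reports whether the first SACK block is a duplicate-SACK per
// RFC 2883: below the cumulative ACK, or covered by the second block.
func isDSACK(ackNum uint32, blocks []sackBlock) bool {
	if len(blocks) == 0 {
		return false
	}
	first := blocks[0]
	if first.end <= ackNum {
		return true
	}
	if len(blocks) > 1 {
		second := blocks[1]
		return second.start <= first.start && first.end <= second.end
	}
	return false
}

func main() {
	// A duplicate segment 1000-1500 reported below a cumulative ACK of 2000.
	fmt.Println(isDSACK(2000, []sackBlock{{1000, 1500}})) // true
	// First block ahead of the ACK and not nested in the second: plain SACK.
	fmt.Println(isDSACK(2000, []sackBlock{{3000, 3500}, {4000, 4500}})) // false
}
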
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 6255355bb..896249d2d 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -23,6 +23,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) {
t.Error(err)
}
}
+
+func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) {
+ t.Helper()
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+
+func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) {
+ payloadLen := uint32(len(tcpHdr.Payload()))
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead),
+ checker.TCPAckNum(context.TestInitialSequenceNumber+1),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
+ ),
+ )
+ pdata := data[bytesRead : bytesRead+payloadLen]
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
+ }
+}
+
+func buildTSOptionFromHeader(tcpHdr header.TCP) []byte {
+ parsedOpts := tcpHdr.ParsedOptions()
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
+ return tsOpt[:]
+}
+
+func TestDetectSpuriousRecoveryWithRTO(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Expect #5 segment with TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Expect #1 segment because of RTO.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.RTORecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ numAck := 0
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if numAck < 3 {
+ numAck++
+ return
+ }
+
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Send ACK for #3 and #4 segments to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.SACKRecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+func TestNoSpuriousRecoveryWithDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Send ACK for #3 and #4 segments to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ // Acknowledge the data with DSACK for #1 segment.
+ start = c.IRS.Add(maxPayload + 1)
+ end = start.Add(2 * maxPayload)
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}})
+
+ verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */)
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 049957b81..39b1e08c0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -233,7 +233,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
}
switch p.netProto {
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 234125c38..8902be2d3 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/eventfd",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go
index 40fa72925..0dc0c37bd 100644
--- a/pkg/unet/unet.go
+++ b/pkg/unet/unet.go
@@ -23,6 +23,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -55,15 +56,6 @@ func socket(packet bool) (int, error) {
return fd, nil
}
-// eventFD returns a new event FD with initial value 0.
-func eventFD() (int, error) {
- f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if e != 0 {
- return -1, e
- }
- return int(f), nil
-}
-
// Socket is a connected unix domain socket.
type Socket struct {
// gate protects use of fd.
@@ -78,7 +70,7 @@ type Socket struct {
// efd is an event FD that is signaled when the socket is closing.
//
// efd is immutable and remains valid until Close/Release.
- efd int
+ efd eventfd.Eventfd
// race is an atomic variable used to avoid triggering the race
// detector. See comment in SocketPair below.
@@ -95,7 +87,7 @@ func NewSocket(fd int) (*Socket, error) {
return nil, err
}
- efd, err := eventFD()
+ efd, err := eventfd.Create()
if err != nil {
return nil, err
}
@@ -110,16 +102,14 @@ func NewSocket(fd int) (*Socket, error) {
// closing the event FD.
func (s *Socket) finish() error {
// Signal any blocked or future polls.
- //
- // N.B. eventfd writes must be 8 bytes.
- if _, err := unix.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil {
+ if err := s.efd.Notify(); err != nil {
return err
}
// Close the gate, blocking until all FD users leave.
s.gate.Close()
- return unix.Close(s.efd)
+ return s.efd.Close()
}
// Close closes the socket.
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
index f0bf93ddd..ea281fec3 100644
--- a/pkg/unet/unet_unsafe.go
+++ b/pkg/unet/unet_unsafe.go
@@ -43,7 +43,7 @@ func (s *Socket) wait(write bool) error {
},
{
// The eventfd, signaled when we are closing.
- Fd: int32(s.efd),
+ Fd: int32(s.efd.FD()),
Events: unix.POLLIN,
},
}
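
The unet changes above swap a hand-rolled SYS_EVENTFD2 call for the new pkg/eventfd package, whose Create/Notify/Close/FD surface is visible in the hunks. A rough, Linux-only sketch of what such a wrapper looks like (an approximation for illustration, not the actual pkg/eventfd source):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// Eventfd wraps a host eventfd; an 8-byte write signals any poller.
type Eventfd struct{ fd int }

func Create() (Eventfd, error) {
	fd, err := unix.Eventfd(0, 0)
	if err != nil {
		return Eventfd{}, err
	}
	return Eventfd{fd: fd}, nil
}

// Notify increments the counter, waking anyone polling FD() for POLLIN.
func (e Eventfd) Notify() error {
	// Eventfd writes must be exactly 8 bytes (little-endian uint64).
	_, err := unix.Write(e.fd, []byte{1, 0, 0, 0, 0, 0, 0, 0})
	return err
}

func (e Eventfd) FD() int      { return e.fd }
func (e Eventfd) Close() error { return unix.Close(e.fd) }

func main() {
	efd, err := Create()
	if err != nil {
		panic(err)
	}
	defer efd.Close()
	if err := efd.Notify(); err != nil {
		panic(err)
	}
	fmt.Println("eventfd", efd.FD(), "signaled")
}
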