summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/capability.go84
-rw-r--r--pkg/abi/linux/mm.go9
-rw-r--r--pkg/abi/linux/prctl.go7
-rw-r--r--pkg/abi/linux/socket.go34
-rw-r--r--pkg/memutil/BUILD11
-rw-r--r--pkg/memutil/memutil_unsafe.go (renamed from pkg/sentry/memutil/memutil_unsafe.go)3
-rw-r--r--pkg/procid/BUILD (renamed from pkg/sentry/platform/procid/BUILD)4
-rw-r--r--pkg/procid/procid.go (renamed from pkg/sentry/platform/procid/procid.go)0
-rw-r--r--pkg/procid/procid_amd64.s (renamed from pkg/sentry/platform/procid/procid_amd64.s)0
-rw-r--r--pkg/procid/procid_arm64.s (renamed from pkg/sentry/platform/procid/procid_arm64.s)0
-rw-r--r--pkg/procid/procid_net_test.go (renamed from pkg/sentry/platform/procid/procid_net_test.go)0
-rw-r--r--pkg/procid/procid_test.go (renamed from pkg/sentry/platform/procid/procid_test.go)0
-rw-r--r--pkg/sentry/context/contexttest/BUILD2
-rw-r--r--pkg/sentry/context/contexttest/contexttest.go2
-rw-r--r--pkg/sentry/fs/dirent.go2
-rw-r--r--pkg/sentry/fs/file.go24
-rw-r--r--pkg/sentry/fs/gofer/socket.go16
-rw-r--r--pkg/sentry/fs/host/socket.go73
-rw-r--r--pkg/sentry/fs/host/socket_test.go156
-rw-r--r--pkg/sentry/fs/inode.go4
-rw-r--r--pkg/sentry/fs/inode_overlay.go12
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/inode.go40
-rw-r--r--pkg/sentry/fs/proc/net.go34
-rw-r--r--pkg/sentry/fs/proc/task.go17
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go2
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go25
-rw-r--r--pkg/sentry/hostmm/BUILD18
-rw-r--r--pkg/sentry/hostmm/cgroup.go111
-rw-r--r--pkg/sentry/hostmm/hostmm.go130
-rw-r--r--pkg/sentry/kernel/BUILD13
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go2
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/kernel/kernel.go55
-rw-r--r--pkg/sentry/kernel/pipe/node.go12
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go8
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go39
-rw-r--r--pkg/sentry/kernel/ptrace.go17
-rw-r--r--pkg/sentry/kernel/syscalls.go60
-rw-r--r--pkg/sentry/kernel/table_test.go8
-rw-r--r--pkg/sentry/kernel/task.go7
-rw-r--r--pkg/sentry/kernel/task_exec.go7
-rw-r--r--pkg/sentry/kernel/task_identity.go24
-rw-r--r--pkg/sentry/kernel/task_sched.go4
-rw-r--r--pkg/sentry/memutil/BUILD14
-rw-r--r--pkg/sentry/mm/lifecycle.go6
-rw-r--r--pkg/sentry/mm/metadata.go30
-rw-r--r--pkg/sentry/mm/mm.go12
-rw-r--r--pkg/sentry/mm/syscalls.go63
-rw-r--r--pkg/sentry/mm/vma.go3
-rw-r--r--pkg/sentry/pgalloc/BUILD3
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go63
-rw-r--r--pkg/sentry/platform/kvm/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/machine.go2
-rw-r--r--pkg/sentry/platform/ptrace/BUILD2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go21
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go37
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go4
-rw-r--r--pkg/sentry/socket/control/control.go12
-rw-r--r--pkg/sentry/socket/epsocket/BUILD1
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go96
-rw-r--r--pkg/sentry/socket/epsocket/provider.go9
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go61
-rw-r--r--pkg/sentry/socket/netlink/provider.go9
-rw-r--r--pkg/sentry/socket/netlink/socket.go17
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD1
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go31
-rw-r--r--pkg/sentry/socket/socket.go21
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go29
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go20
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go30
-rw-r--r--pkg/sentry/socket/unix/unix.go80
-rw-r--r--pkg/sentry/strace/socket.go14
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/error.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go756
-rw-r--r--pkg/sentry/syscalls/linux/sys_mempolicy.go312
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go145
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go4
-rw-r--r--pkg/sentry/syscalls/syscalls.go88
-rw-r--r--pkg/sentry/time/BUILD2
-rw-r--r--pkg/sentry/usage/BUILD2
-rw-r--r--pkg/sentry/usage/memory.go2
-rw-r--r--pkg/sentry/usermem/usermem.go36
-rw-r--r--pkg/sentry/usermem/usermem_test.go15
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go168
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go2
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go10
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go4
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go2
-rw-r--r--pkg/tcpip/stack/route.go10
-rw-r--r--pkg/tcpip/stack/transport_test.go4
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go50
-rw-r--r--pkg/tcpip/transport/tcp/connect.go52
-rw-r--r--pkg/tcpip/transport/tcp/cubic.go3
-rw-r--r--pkg/tcpip/transport/tcp/cubic_state.go (renamed from pkg/sentry/memutil/memutil.go)19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go223
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go44
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go26
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go37
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go6
-rw-r--r--pkg/tcpip/transport/tcp/snd.go16
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go519
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go963
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go85
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
-rw-r--r--pkg/urpc/urpc.go2
117 files changed, 3492 insertions, 1965 deletions
diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go
index c120cac64..65dd77e6e 100644
--- a/pkg/abi/linux/capability.go
+++ b/pkg/abi/linux/capability.go
@@ -69,6 +69,90 @@ func (cp Capability) Ok() bool {
return cp >= 0 && cp <= MaxCapability
}
+// String returns the capability name.
+func (cp Capability) String() string {
+ switch cp {
+ case CAP_CHOWN:
+ return "CAP_CHOWN"
+ case CAP_DAC_OVERRIDE:
+ return "CAP_DAC_OVERRIDE"
+ case CAP_DAC_READ_SEARCH:
+ return "CAP_DAC_READ_SEARCH"
+ case CAP_FOWNER:
+ return "CAP_FOWNER"
+ case CAP_FSETID:
+ return "CAP_FSETID"
+ case CAP_KILL:
+ return "CAP_KILL"
+ case CAP_SETGID:
+ return "CAP_SETGID"
+ case CAP_SETUID:
+ return "CAP_SETUID"
+ case CAP_SETPCAP:
+ return "CAP_SETPCAP"
+ case CAP_LINUX_IMMUTABLE:
+ return "CAP_LINUX_IMMUTABLE"
+ case CAP_NET_BIND_SERVICE:
+ return "CAP_NET_BIND_SERVICE"
+ case CAP_NET_BROADCAST:
+ return "CAP_NET_BROADCAST"
+ case CAP_NET_ADMIN:
+ return "CAP_NET_ADMIN"
+ case CAP_NET_RAW:
+ return "CAP_NET_RAW"
+ case CAP_IPC_LOCK:
+ return "CAP_IPC_LOCK"
+ case CAP_IPC_OWNER:
+ return "CAP_IPC_OWNER"
+ case CAP_SYS_MODULE:
+ return "CAP_SYS_MODULE"
+ case CAP_SYS_RAWIO:
+ return "CAP_SYS_RAWIO"
+ case CAP_SYS_CHROOT:
+ return "CAP_SYS_CHROOT"
+ case CAP_SYS_PTRACE:
+ return "CAP_SYS_PTRACE"
+ case CAP_SYS_PACCT:
+ return "CAP_SYS_PACCT"
+ case CAP_SYS_ADMIN:
+ return "CAP_SYS_ADMIN"
+ case CAP_SYS_BOOT:
+ return "CAP_SYS_BOOT"
+ case CAP_SYS_NICE:
+ return "CAP_SYS_NICE"
+ case CAP_SYS_RESOURCE:
+ return "CAP_SYS_RESOURCE"
+ case CAP_SYS_TIME:
+ return "CAP_SYS_TIME"
+ case CAP_SYS_TTY_CONFIG:
+ return "CAP_SYS_TTY_CONFIG"
+ case CAP_MKNOD:
+ return "CAP_MKNOD"
+ case CAP_LEASE:
+ return "CAP_LEASE"
+ case CAP_AUDIT_WRITE:
+ return "CAP_AUDIT_WRITE"
+ case CAP_AUDIT_CONTROL:
+ return "CAP_AUDIT_CONTROL"
+ case CAP_SETFCAP:
+ return "CAP_SETFCAP"
+ case CAP_MAC_OVERRIDE:
+ return "CAP_MAC_OVERRIDE"
+ case CAP_MAC_ADMIN:
+ return "CAP_MAC_ADMIN"
+ case CAP_SYSLOG:
+ return "CAP_SYSLOG"
+ case CAP_WAKE_ALARM:
+ return "CAP_WAKE_ALARM"
+ case CAP_BLOCK_SUSPEND:
+ return "CAP_BLOCK_SUSPEND"
+ case CAP_AUDIT_READ:
+ return "CAP_AUDIT_READ"
+ default:
+ return "UNKNOWN"
+ }
+}
+
// Version numbers used by the capget/capset syscalls, defined in Linux's
// include/uapi/linux/capability.h.
const (
diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go
index 0b02f938a..cd043dac3 100644
--- a/pkg/abi/linux/mm.go
+++ b/pkg/abi/linux/mm.go
@@ -114,3 +114,12 @@ const (
MPOL_MODE_FLAGS = (MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES)
)
+
+// Flags for mbind(2).
+const (
+ MPOL_MF_STRICT = 1 << 0
+ MPOL_MF_MOVE = 1 << 1
+ MPOL_MF_MOVE_ALL = 1 << 2
+
+ MPOL_MF_VALID = MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL
+)
diff --git a/pkg/abi/linux/prctl.go b/pkg/abi/linux/prctl.go
index 0428282dd..391cfaa1c 100644
--- a/pkg/abi/linux/prctl.go
+++ b/pkg/abi/linux/prctl.go
@@ -155,3 +155,10 @@ const (
ARCH_GET_GS = 0x1004
ARCH_SET_CPUID = 0x1012
)
+
+// Flags for prctl(PR_SET_DUMPABLE), defined in include/linux/sched/coredump.h.
+const (
+ SUID_DUMP_DISABLE = 0
+ SUID_DUMP_USER = 1
+ SUID_DUMP_ROOT = 2
+)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 417840731..a714ac86d 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -102,15 +102,19 @@ const (
SOL_NETLINK = 270
)
+// A SockType is a type (as opposed to family) of sockets. These are enumerated
+// below as SOCK_* constants.
+type SockType int
+
// Socket types, from linux/net.h.
const (
- SOCK_STREAM = 1
- SOCK_DGRAM = 2
- SOCK_RAW = 3
- SOCK_RDM = 4
- SOCK_SEQPACKET = 5
- SOCK_DCCP = 6
- SOCK_PACKET = 10
+ SOCK_STREAM SockType = 1
+ SOCK_DGRAM = 2
+ SOCK_RAW = 3
+ SOCK_RDM = 4
+ SOCK_SEQPACKET = 5
+ SOCK_DCCP = 6
+ SOCK_PACKET = 10
)
// SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are
@@ -200,6 +204,22 @@ const (
SS_DISCONNECTING = 4 // In process of disconnecting.
)
+// TCP protocol states, from include/net/tcp_states.h.
+const (
+ TCP_ESTABLISHED uint32 = iota + 1
+ TCP_SYN_SENT
+ TCP_SYN_RECV
+ TCP_FIN_WAIT1
+ TCP_FIN_WAIT2
+ TCP_TIME_WAIT
+ TCP_CLOSE
+ TCP_CLOSE_WAIT
+ TCP_LAST_ACK
+ TCP_LISTEN
+ TCP_CLOSING
+ TCP_NEW_SYN_RECV
+)
+
// SockAddrMax is the maximum size of a struct sockaddr, from
// uapi/linux/socket.h.
const SockAddrMax = 128
diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD
new file mode 100644
index 000000000..71b48a972
--- /dev/null
+++ b/pkg/memutil/BUILD
@@ -0,0 +1,11 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memutil",
+ srcs = ["memutil_unsafe.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/memutil",
+ visibility = ["//visibility:public"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/sentry/memutil/memutil_unsafe.go b/pkg/memutil/memutil_unsafe.go
index 92eab8a26..979d942a9 100644
--- a/pkg/sentry/memutil/memutil_unsafe.go
+++ b/pkg/memutil/memutil_unsafe.go
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build linux
+
+// Package memutil provides a wrapper for the memfd_create() system call.
package memutil
import (
diff --git a/pkg/sentry/platform/procid/BUILD b/pkg/procid/BUILD
index 277509624..7c22b763a 100644
--- a/pkg/sentry/platform/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -9,8 +9,8 @@ go_library(
"procid_amd64.s",
"procid_arm64.s",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid",
- visibility = ["//pkg/sentry:internal"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/procid",
+ visibility = ["//visibility:public"],
)
go_test(
diff --git a/pkg/sentry/platform/procid/procid.go b/pkg/procid/procid.go
index 78b92422c..78b92422c 100644
--- a/pkg/sentry/platform/procid/procid.go
+++ b/pkg/procid/procid.go
diff --git a/pkg/sentry/platform/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
index 30ec8e6e2..30ec8e6e2 100644
--- a/pkg/sentry/platform/procid/procid_amd64.s
+++ b/pkg/procid/procid_amd64.s
diff --git a/pkg/sentry/platform/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
index e340d9f98..e340d9f98 100644
--- a/pkg/sentry/platform/procid/procid_arm64.s
+++ b/pkg/procid/procid_arm64.s
diff --git a/pkg/sentry/platform/procid/procid_net_test.go b/pkg/procid/procid_net_test.go
index b628e2285..b628e2285 100644
--- a/pkg/sentry/platform/procid/procid_net_test.go
+++ b/pkg/procid/procid_net_test.go
diff --git a/pkg/sentry/platform/procid/procid_test.go b/pkg/procid/procid_test.go
index 88dd0b3ae..88dd0b3ae 100644
--- a/pkg/sentry/platform/procid/procid_test.go
+++ b/pkg/procid/procid_test.go
diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD
index ce4f1e42c..d17b1bdcf 100644
--- a/pkg/sentry/context/contexttest/BUILD
+++ b/pkg/sentry/context/contexttest/BUILD
@@ -9,11 +9,11 @@ go_library(
importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/context/contexttest",
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/memutil",
"//pkg/sentry/context",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
- "//pkg/sentry/memutil",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
"//pkg/sentry/platform/ptrace",
diff --git a/pkg/sentry/context/contexttest/contexttest.go b/pkg/sentry/context/contexttest/contexttest.go
index 210a235d2..83da40711 100644
--- a/pkg/sentry/context/contexttest/contexttest.go
+++ b/pkg/sentry/context/contexttest/contexttest.go
@@ -21,11 +21,11 @@ import (
"testing"
"time"
+ "gvisor.googlesource.com/gvisor/pkg/memutil"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/limits"
- "gvisor.googlesource.com/gvisor/pkg/sentry/memutil"
"gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc"
"gvisor.googlesource.com/gvisor/pkg/sentry/platform"
"gvisor.googlesource.com/gvisor/pkg/sentry/platform/ptrace"
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index c0bc261a2..a0a35c242 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -805,7 +805,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans
var childDir *Dirent
err := d.genericCreate(ctx, root, name, func() error {
var e error
- childDir, e = d.Inode.Bind(ctx, name, data, perms)
+ childDir, e = d.Inode.Bind(ctx, d, name, data, perms)
if e != nil {
return e
}
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 8c1307235..f64954457 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -545,12 +545,28 @@ type lockedWriter struct {
// Write implements io.Writer.Write.
func (w *lockedWriter) Write(buf []byte) (int, error) {
- n, err := w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf), w.File.offset)
- return int(n), err
+ return w.WriteAt(buf, w.File.offset)
}
// WriteAt implements io.Writer.WriteAt.
func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
- n, err := w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf), offset)
- return int(n), err
+ var (
+ written int
+ err error
+ )
+ // The io.Writer contract requires that Write writes all available
+ // bytes and does not return short writes. This causes errors with
+ // io.Copy, since our own Write interface does not have this same
+ // contract. Enforce that here.
+ for written < len(buf) {
+ var n int64
+ n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
+ if n > 0 {
+ written += int(n)
+ }
+ if err != nil {
+ break
+ }
+ }
+ return written, err
}
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index cbd5b9a84..7ac0a421f 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -15,6 +15,7 @@
package gofer
import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/p9"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
@@ -61,13 +62,13 @@ type endpoint struct {
path string
}
-func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) {
+func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
switch t {
- case transport.SockStream:
+ case linux.SOCK_STREAM:
return p9.StreamSocket, true
- case transport.SockSeqpacket:
+ case linux.SOCK_SEQPACKET:
return p9.SeqpacketSocket, true
- case transport.SockDgram:
+ case linux.SOCK_DGRAM:
return p9.DgramSocket, true
}
return 0, false
@@ -75,7 +76,7 @@ func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) {
// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
- cf, ok := unixSockToP9(ce.Type())
+ cf, ok := sockTypeToP9(ce.Type())
if !ok {
return syserr.ErrConnectionRefused
}
@@ -139,3 +140,8 @@ func (e *endpoint) UnidirectionalConnect() (transport.ConnectedEndpoint, *syserr
func (e *endpoint) Release() {
e.inode.DecRef()
}
+
+// Passcred implements transport.BoundEndpoint.Passcred.
+func (e *endpoint) Passcred() bool {
+ return false
+}
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 3ed137006..305eea718 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -15,9 +15,11 @@
package host
import (
+ "fmt"
"sync"
"syscall"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/fd"
"gvisor.googlesource.com/gvisor/pkg/fdnotifier"
"gvisor.googlesource.com/gvisor/pkg/log"
@@ -51,25 +53,11 @@ type ConnectedEndpoint struct {
// ref keeps track of references to a connectedEndpoint.
ref refs.AtomicRefCount
- // mu protects fd, readClosed and writeClosed.
- mu sync.RWMutex `state:"nosave"`
-
- // file is an *fd.FD containing the FD backing this endpoint. It must be
- // set to nil if it has been closed.
- file *fd.FD `state:"nosave"`
-
- // readClosed is true if the FD has read shutdown or if it has been closed.
- readClosed bool
-
- // writeClosed is true if the FD has write shutdown or if it has been
- // closed.
- writeClosed bool
-
// If srfd >= 0, it is the host FD that file was imported from.
srfd int `state:"wait"`
// stype is the type of Unix socket.
- stype transport.SockType
+ stype linux.SockType
// sndbuf is the size of the send buffer.
//
@@ -78,6 +66,13 @@ type ConnectedEndpoint struct {
// prevent lots of small messages from filling the real send buffer
// size on the host.
sndbuf int `state:"nosave"`
+
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // file is an *fd.FD containing the FD backing this endpoint. It must be
+ // set to nil if it has been closed.
+ file *fd.FD `state:"nosave"`
}
// init performs initialization required for creating new ConnectedEndpoints and
@@ -111,7 +106,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
return syserr.ErrInvalidEndpointState
}
- c.stype = transport.SockType(stype)
+ c.stype = linux.SockType(stype)
c.sndbuf = sndbuf
return nil
@@ -169,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
- return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil
+ return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil
}
// newSocket allocates a new unix socket with host endpoint.
@@ -201,16 +196,13 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
- return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil
+ return unixsocket.New(ctx, ep, e.stype), nil
}
// Send implements transport.ConnectedEndpoint.Send.
func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.writeClosed {
- return 0, false, syserr.ErrClosedForSend
- }
if !controlMessages.Empty() {
return 0, false, syserr.ErrInvalidEndpointState
@@ -218,7 +210,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.Contro
// Since stream sockets don't preserve message boundaries, we can write
// only as much of the message as fits in the send buffer.
- truncate := c.stype == transport.SockStream
+ truncate := c.stype == linux.SOCK_STREAM
n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate)
if n < totalLen && err == nil {
@@ -244,8 +236,13 @@ func (c *ConnectedEndpoint) SendNotify() {}
// CloseSend implements transport.ConnectedEndpoint.CloseSend.
func (c *ConnectedEndpoint) CloseSend() {
c.mu.Lock()
- c.writeClosed = true
- c.mu.Unlock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
}
// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
@@ -255,9 +252,7 @@ func (c *ConnectedEndpoint) CloseNotify() {}
func (c *ConnectedEndpoint) Writable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.writeClosed {
- return true
- }
+
return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0
}
@@ -285,9 +280,6 @@ func (c *ConnectedEndpoint) EventUpdate() {
func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.readClosed {
- return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive
- }
var cm unet.ControlMessage
if numRights > 0 {
@@ -344,31 +336,34 @@ func (c *ConnectedEndpoint) RecvNotify() {}
// CloseRecv implements transport.Receiver.CloseRecv.
func (c *ConnectedEndpoint) CloseRecv() {
c.mu.Lock()
- c.readClosed = true
- c.mu.Unlock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
}
// Readable implements transport.Receiver.Readable.
func (c *ConnectedEndpoint) Readable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.readClosed {
- return true
- }
+
return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0
}
// SendQueuedSize implements transport.Receiver.SendQueuedSize.
func (c *ConnectedEndpoint) SendQueuedSize() int64 {
- // SendQueuedSize isn't supported for host sockets because we don't allow the
- // sentry to call ioctl(2).
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
return -1
}
// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
- // RecvQueuedSize isn't supported for host sockets because we don't allow the
- // sentry to call ioctl(2).
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
return -1
}
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index 06392a65a..bc3ce5627 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -198,20 +198,6 @@ func TestListen(t *testing.T) {
}
}
-func TestSend(t *testing.T) {
- e := ConnectedEndpoint{writeClosed: true}
- if _, _, err := e.Send(nil, transport.ControlMessages{}, tcpip.FullAddress{}); err != syserr.ErrClosedForSend {
- t.Errorf("Got %#v.Send() = %v, want = %v", e, err, syserr.ErrClosedForSend)
- }
-}
-
-func TestRecv(t *testing.T) {
- e := ConnectedEndpoint{readClosed: true}
- if _, _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive {
- t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, syserr.ErrClosedForReceive)
- }
-}
-
func TestPasscred(t *testing.T) {
e := ConnectedEndpoint{}
if got, want := e.Passcred(), false; got != want {
@@ -244,20 +230,6 @@ func TestQueuedSize(t *testing.T) {
}
}
-func TestReadable(t *testing.T) {
- e := ConnectedEndpoint{readClosed: true}
- if got, want := e.Readable(), true; got != want {
- t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want)
- }
-}
-
-func TestWritable(t *testing.T) {
- e := ConnectedEndpoint{writeClosed: true}
- if got, want := e.Writable(), true; got != want {
- t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want)
- }
-}
-
func TestRelease(t *testing.T) {
f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
@@ -272,131 +244,3 @@ func TestRelease(t *testing.T) {
t.Errorf("got = %#v, want = %#v", c, want)
}
}
-
-func TestClose(t *testing.T) {
- type testCase struct {
- name string
- cep *ConnectedEndpoint
- addFD bool
- f func()
- want *ConnectedEndpoint
- }
-
- var tests []testCase
-
- // nil is the value used by ConnectedEndpoint to indicate a closed file.
- // Non-nil files are used to check if the file gets closed.
-
- f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
- tests = append(tests, testCase{
- name: "First CloseRecv",
- cep: c,
- addFD: false,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
- tests = append(tests, testCase{
- name: "Second CloseRecv",
- cep: c,
- addFD: false,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
- tests = append(tests, testCase{
- name: "First CloseSend",
- cep: c,
- addFD: false,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
- tests = append(tests, testCase{
- name: "Second CloseSend",
- cep: c,
- addFD: false,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true}
- tests = append(tests, testCase{
- name: "CloseSend then CloseRecv",
- cep: c,
- addFD: true,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true}
- tests = append(tests, testCase{
- name: "CloseRecv then CloseSend",
- cep: c,
- addFD: true,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
- tests = append(tests, testCase{
- name: "Full close then CloseRecv",
- cep: c,
- addFD: false,
- f: c.CloseRecv,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
- if err != nil {
- t.Fatal("Creating socket:", err)
- }
- c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true}
- tests = append(tests, testCase{
- name: "Full close then CloseSend",
- cep: c,
- addFD: false,
- f: c.CloseSend,
- want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true},
- })
-
- for _, test := range tests {
- if test.addFD {
- fdnotifier.AddFD(int32(test.cep.file.FD()), nil)
- }
- if test.f(); !reflect.DeepEqual(test.cep, test.want) {
- t.Errorf("%s: got = %#v, want = %#v", test.name, test.cep, test.want)
- }
- }
-}
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index aef1a1cb9..0b54c2e77 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -220,9 +220,9 @@ func (i *Inode) Rename(ctx context.Context, oldParent *Dirent, renamed *Dirent,
}
// Bind calls i.InodeOperations.Bind with i as the directory.
-func (i *Inode) Bind(ctx context.Context, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+func (i *Inode) Bind(ctx context.Context, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
if i.overlay != nil {
- return overlayBind(ctx, i.overlay, name, data, perm)
+ return overlayBind(ctx, i.overlay, parent, name, data, perm)
}
return i.InodeOperations.Bind(ctx, i, name, data, perm)
}
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index cdffe173b..06506fb20 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -398,14 +398,14 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
return nil
}
-func overlayBind(ctx context.Context, o *overlayEntry, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
+ if err := copyUp(ctx, parent); err != nil {
+ return nil, err
+ }
+
o.copyMu.RLock()
defer o.copyMu.RUnlock()
- // We do not support doing anything exciting with sockets unless there
- // is already a directory in the upper filesystem.
- if o.upper == nil {
- return nil, syserror.EOPNOTSUPP
- }
+
d, err := o.upper.InodeOperations.Bind(ctx, o.upper, name, data, perm)
if err != nil {
return nil, err
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index d19c360e0..1728fe0b5 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -45,6 +45,7 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/mm",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/rpcinet",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
index 379569823..986bc0a45 100644
--- a/pkg/sentry/fs/proc/inode.go
+++ b/pkg/sentry/fs/proc/inode.go
@@ -21,11 +21,14 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
)
// taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr
-// method to return the task as the owner.
+// method to return either the task or root as the owner, depending on the
+// task's dumpability.
//
// +stateify savable
type taskOwnedInodeOps struct {
@@ -41,9 +44,42 @@ func (i *taskOwnedInodeOps) UnstableAttr(ctx context.Context, inode *fs.Inode) (
if err != nil {
return fs.UnstableAttr{}, err
}
- // Set the task owner as the file owner.
+
+ // By default, set the task owner as the file owner.
creds := i.t.Credentials()
uattr.Owner = fs.FileOwner{creds.EffectiveKUID, creds.EffectiveKGID}
+
+ // Linux doesn't apply dumpability adjustments to world
+ // readable/executable directories so that applications can stat
+ // /proc/PID to determine the effective UID of a process. See
+ // fs/proc/base.c:task_dump_owner.
+ if fs.IsDir(inode.StableAttr) && uattr.Perms == fs.FilePermsFromMode(0555) {
+ return uattr, nil
+ }
+
+ // If the task is not dumpable, then root (in the namespace preferred)
+ // owns the file.
+ var m *mm.MemoryManager
+ i.t.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+
+ if m == nil {
+ uattr.Owner.UID = auth.RootKUID
+ uattr.Owner.GID = auth.RootKGID
+ } else if m.Dumpability() != mm.UserDumpable {
+ if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() {
+ uattr.Owner.UID = kuid
+ } else {
+ uattr.Owner.UID = auth.RootKUID
+ }
+ if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() {
+ uattr.Owner.GID = kgid
+ } else {
+ uattr.Owner.GID = auth.RootKGID
+ }
+ }
+
return uattr, nil
}
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 4a107c739..034950158 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -27,6 +27,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
"gvisor.googlesource.com/gvisor/pkg/sentry/inet"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
)
@@ -213,17 +214,18 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n")
// Entries
- for _, sref := range n.k.ListSockets(linux.AF_UNIX) {
- s := sref.Get()
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref)
+ log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
continue
}
sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(*unix.SocketOperations)
- if !ok {
- panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile))
+ if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+ // Not a unix socket.
+ continue
}
+ sops := sfile.FileOperations.(*unix.SocketOperations)
addr, err := sops.Endpoint().GetLocalAddress()
if err != nil {
@@ -240,24 +242,6 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
}
}
- var sockState int
- switch sops.Endpoint().Type() {
- case linux.SOCK_DGRAM:
- sockState = linux.SS_CONNECTING
- // Unlike Linux, we don't have unbound connection-less sockets,
- // so no SS_DISCONNECTING.
-
- case linux.SOCK_SEQPACKET:
- fallthrough
- case linux.SOCK_STREAM:
- // Connectioned.
- if sops.Endpoint().(transport.ConnectingEndpoint).Connected() {
- sockState = linux.SS_CONNECTED
- } else {
- sockState = linux.SS_UNCONNECTED
- }
- }
-
// In the socket entry below, the value for the 'Num' field requires
// some consideration. Linux prints the address to the struct
// unix_sock representing a socket in the kernel, but may redact the
@@ -282,7 +266,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
0, // Protocol, always 0 for UDS.
sockFlags, // Flags.
sops.Endpoint().Type(), // Type.
- sockState, // State.
+ sops.State(), // State.
sfile.InodeID(), // Inode.
)
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 77e03d349..21a965f90 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -96,7 +96,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks boo
contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers)
}
- // TODO(b/31916171): Set EUID/EGID based on dumpability.
+ // N.B. taskOwnedInodeOps enforces dumpability-based ownership.
d := &taskDir{
Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)),
t: t,
@@ -667,6 +667,21 @@ func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(c, msrc, fs.SpecialFile, t)
}
+// Check implements fs.InodeOperations.Check.
+func (c *comm) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool {
+ // This file can always be read or written by members of the same
+ // thread group. See fs/proc/base.c:proc_tid_comm_permission.
+ //
+ // N.B. This check is currently a no-op as we don't yet support writing
+ // and this file is world-readable anyways.
+ t := kernel.TaskFromContext(ctx)
+ if t != nil && t.ThreadGroup() == c.t.ThreadGroup() && !p.Execute {
+ return true
+ }
+
+ return fs.ContextCanAccessFile(ctx, inode, p)
+}
+
// GetFile implements fs.InodeOperations.GetFile.
func (c *comm) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
return fs.NewFile(ctx, dirent, flags, &commFile{t: c.t}), nil
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index bce5f091d..c1721f434 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -54,6 +54,8 @@ type TimerOperations struct {
// NewFile returns a timerfd File that receives time from c.
func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[timerfd]")
+ // Release the initial dirent reference after NewFile takes a reference.
+ defer dirent.DecRef()
tops := &TimerOperations{}
tops.timer = ktime.NewTimer(c, tops)
// Timerfds reject writes, but the Write flag must be set in order to
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
index b7c29a4d1..83e1bf247 100644
--- a/pkg/sentry/fs/tmpfs/fs.go
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -34,6 +34,16 @@ const (
// GID for the root directory.
rootGIDKey = "gid"
+ // cacheKey sets the caching policy for the mount.
+ cacheKey = "cache"
+
+ // cacheAll uses the virtual file system cache for everything (default).
+ cacheAll = "cache"
+
+ // cacheRevalidate allows dirents to be cached, but revalidates them on each
+ // lookup.
+ cacheRevalidate = "revalidate"
+
// TODO(edahlgren/mpratt): support a tmpfs size limit.
// size = "size"
@@ -122,15 +132,24 @@ func (f *Filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
delete(options, rootGIDKey)
}
+ // Construct a mount which will follow the cache options provided.
+ var msrc *fs.MountSource
+ switch options[cacheKey] {
+ case "", cacheAll:
+ msrc = fs.NewCachingMountSource(f, flags)
+ case cacheRevalidate:
+ msrc = fs.NewRevalidatingMountSource(f, flags)
+ default:
+ return nil, fmt.Errorf("invalid cache policy option %q", options[cacheKey])
+ }
+ delete(options, cacheKey)
+
// Fail if the caller passed us more options than we can parse. They may be
// expecting us to set something we can't set.
if len(options) > 0 {
return nil, fmt.Errorf("unsupported mount options: %v", options)
}
- // Construct a mount which will cache dirents.
- msrc := fs.NewCachingMountSource(f, flags)
-
// Construct the tmpfs root.
return NewDir(ctx, nil, owner, perms, msrc), nil
}
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
new file mode 100644
index 000000000..1a4632a54
--- /dev/null
+++ b/pkg/sentry/hostmm/BUILD
@@ -0,0 +1,18 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hostmm",
+ srcs = [
+ "cgroup.go",
+ "hostmm.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/hostmm",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/sentry/usermem",
+ ],
+)
diff --git a/pkg/sentry/hostmm/cgroup.go b/pkg/sentry/hostmm/cgroup.go
new file mode 100644
index 000000000..e5cc26ab2
--- /dev/null
+++ b/pkg/sentry/hostmm/cgroup.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostmm
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "path"
+ "strings"
+)
+
+// currentCgroupDirectory returns the directory for the cgroup for the given
+// controller in which the calling process resides.
+func currentCgroupDirectory(ctrl string) (string, error) {
+ root, err := cgroupRootDirectory(ctrl)
+ if err != nil {
+ return "", err
+ }
+ cg, err := currentCgroup(ctrl)
+ if err != nil {
+ return "", err
+ }
+ return path.Join(root, cg), nil
+}
+
+// cgroupRootDirectory returns the root directory for the cgroup hierarchy in
+// which the given cgroup controller is mounted in the calling process' mount
+// namespace.
+func cgroupRootDirectory(ctrl string) (string, error) {
+ const path = "/proc/self/mounts"
+ file, err := os.Open(path)
+ if err != nil {
+ return "", err
+ }
+ defer file.Close()
+
+ // Per proc(5) -> fstab(5):
+ // Each line of /proc/self/mounts describes a mount.
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ // Each line consists of 6 space-separated fields. Find the line for
+ // which the third field (fs_vfstype) is cgroup, and the fourth field
+ // (fs_mntops, a comma-separated list of mount options) contains
+ // ctrl.
+ var spec, file, vfstype, mntopts, freq, passno string
+ const nrfields = 6
+ line := scanner.Text()
+ n, err := fmt.Sscan(line, &spec, &file, &vfstype, &mntopts, &freq, &passno)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse %s: %v", path, err)
+ }
+ if n != nrfields {
+ return "", fmt.Errorf("failed to parse %s: line %q: got %d fields, wanted %d", path, line, n, nrfields)
+ }
+ if vfstype != "cgroup" {
+ continue
+ }
+ for _, mntopt := range strings.Split(mntopts, ",") {
+ if mntopt == ctrl {
+ return file, nil
+ }
+ }
+ }
+ return "", fmt.Errorf("no cgroup hierarchy mounted for controller %s", ctrl)
+}
+
+// currentCgroup returns the cgroup for the given controller in which the
+// calling process resides. The returned string is a path that should be
+// interpreted as relative to cgroupRootDirectory(ctrl).
+func currentCgroup(ctrl string) (string, error) {
+ const path = "/proc/self/cgroup"
+ file, err := os.Open(path)
+ if err != nil {
+ return "", err
+ }
+ defer file.Close()
+
+ // Per proc(5) -> cgroups(7):
+ // Each line of /proc/self/cgroups describes a cgroup hierarchy.
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ // Each line consists of 3 colon-separated fields. Find the line for
+ // which the second field (controller-list, a comma-separated list of
+ // cgroup controllers) contains ctrl.
+ line := scanner.Text()
+ const nrfields = 3
+ fields := strings.Split(line, ":")
+ if len(fields) != nrfields {
+ return "", fmt.Errorf("failed to parse %s: line %q: got %d fields, wanted %d", path, line, len(fields), nrfields)
+ }
+ for _, controller := range strings.Split(fields[1], ",") {
+ if controller == ctrl {
+ return fields[2], nil
+ }
+ }
+ }
+ return "", fmt.Errorf("not a member of a cgroup hierarchy for controller %s", ctrl)
+}
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
new file mode 100644
index 000000000..5432cada9
--- /dev/null
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostmm provides tools for interacting with the host Linux kernel's
+// virtual memory management subsystem.
+package hostmm
+
+import (
+ "fmt"
+ "os"
+ "path"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// NotifyCurrentMemcgPressureCallback requests that f is called whenever the
+// calling process' memory cgroup indicates memory pressure of the given level,
+// as specified by Linux's Documentation/cgroup-v1/memory.txt.
+//
+// If NotifyCurrentMemcgPressureCallback succeeds, it returns a function that
+// terminates the requested memory pressure notifications. This function may be
+// called at most once.
+func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) {
+ cgdir, err := currentCgroupDirectory("memory")
+ if err != nil {
+ return nil, err
+ }
+
+ pressurePath := path.Join(cgdir, "memory.pressure_level")
+ pressureFile, err := os.Open(pressurePath)
+ if err != nil {
+ return nil, err
+ }
+ defer pressureFile.Close()
+
+ eventControlPath := path.Join(cgdir, "cgroup.event_control")
+ eventControlFile, err := os.OpenFile(eventControlPath, os.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer eventControlFile.Close()
+
+ eventFD, err := newEventFD()
+ if err != nil {
+ return nil, err
+ }
+
+ // Don't use fmt.Fprintf since the whole string needs to be written in a
+ // single syscall.
+ eventControlStr := fmt.Sprintf("%d %d %s", eventFD.FD(), pressureFile.Fd(), level)
+ if n, err := eventControlFile.Write([]byte(eventControlStr)); n != len(eventControlStr) || err != nil {
+ eventFD.Close()
+ return nil, fmt.Errorf("error writing %q to %s: got (%d, %v), wanted (%d, nil)", eventControlStr, eventControlPath, n, err, len(eventControlStr))
+ }
+
+ log.Debugf("Receiving memory pressure level notifications from %s at level %q", pressurePath, level)
+ const sizeofUint64 = 8
+ // The most significant bit of the eventfd value is set by the stop
+ // function, which is practically unambiguous since it's not plausible for
+ // 2**63 pressure events to occur between eventfd reads.
+ const stopVal = 1 << 63
+ stopCh := make(chan struct{})
+ go func() { // S/R-SAFE: f provides synchronization if necessary
+ rw := fd.NewReadWriter(eventFD.FD())
+ var buf [sizeofUint64]byte
+ for {
+ n, err := rw.Read(buf[:])
+ if err != nil {
+ if err == syscall.EINTR {
+ continue
+ }
+ panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err))
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ val := usermem.ByteOrder.Uint64(buf[:])
+ if val >= stopVal {
+ // Assume this was due to the notifier's "destructor" (the
+ // function returned by NotifyCurrentMemcgPressureCallback
+ // below) being called.
+ eventFD.Close()
+ close(stopCh)
+ return
+ }
+ f()
+ }
+ }()
+ return func() {
+ rw := fd.NewReadWriter(eventFD.FD())
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], stopVal)
+ for {
+ n, err := rw.Write(buf[:])
+ if err != nil {
+ if err == syscall.EINTR {
+ continue
+ }
+ panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err))
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ break
+ }
+ <-stopCh
+ }, nil
+}
+
+func newEventFD() (*fd.FD, error) {
+ f, _, e := syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0)
+ if e != 0 {
+ return nil, fmt.Errorf("failed to create eventfd: %v", e)
+ }
+ return fd.New(int(f)), nil
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 99a2fd964..04e375910 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -64,6 +64,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "socket_list",
+ out = "socket_list.go",
+ package = "kernel",
+ prefix = "socket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*SocketEntry",
+ "Linker": "*SocketEntry",
+ },
+)
+
proto_library(
name = "uncaught_signal_proto",
srcs = ["uncaught_signal.proto"],
@@ -104,6 +116,7 @@ go_library(
"sessions.go",
"signal.go",
"signal_handlers.go",
+ "socket_list.go",
"syscalls.go",
"syscalls_state.go",
"syslog.go",
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index bbacba1f4..43ae22a5d 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -156,6 +156,8 @@ var cycleMu sync.Mutex
func NewEventPoll(ctx context.Context) *fs.File {
// name matches fs/eventpoll.c:epoll_create1.
dirent := fs.NewDirent(anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]"))
+ // Release the initial dirent reference after NewFile takes a reference.
+ defer dirent.DecRef()
return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{
files: make(map[FileIdentifier]*pollEntry),
})
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 2f900be38..fe474cbf0 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -69,6 +69,8 @@ type EventOperations struct {
func New(ctx context.Context, initVal uint64, semMode bool) *fs.File {
// name matches fs/eventfd.c:eventfd_file_create.
dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[eventfd]")
+ // Release the initial dirent reference after NewFile takes a reference.
+ defer dirent.DecRef()
return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{
val: initVal,
semMode: semMode,
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 85d73ace2..f253a81d9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -182,9 +182,13 @@ type Kernel struct {
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
- // socketTable is used to track all sockets on the system. Protected by
+ // sockets is the list of all network sockets the system. Protected by
// extMu.
- socketTable map[int]map[*refs.WeakRef]struct{}
+ sockets socketList
+
+ // nextSocketEntry is the next entry number to use in sockets. Protected
+ // by extMu.
+ nextSocketEntry uint64
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -283,7 +287,6 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
- k.socketTable = make(map[int]map[*refs.WeakRef]struct{})
return nil
}
@@ -1137,51 +1140,43 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
})
}
-// socketEntry represents a socket recorded in Kernel.socketTable. It implements
+// SocketEntry represents a socket recorded in Kernel.sockets. It implements
// refs.WeakRefUser for sockets stored in the socket table.
//
// +stateify savable
-type socketEntry struct {
- k *Kernel
- sock *refs.WeakRef
- family int
+type SocketEntry struct {
+ socketEntry
+ k *Kernel
+ Sock *refs.WeakRef
+ ID uint64 // Socket table entry number.
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *socketEntry) WeakRefGone() {
+func (s *SocketEntry) WeakRefGone() {
s.k.extMu.Lock()
- // k.socketTable is guaranteed to point to a valid socket table for s.family
- // at this point, since we made sure of the fact when we created this
- // socketEntry, and we never delete socket tables.
- delete(s.k.socketTable[s.family], s.sock)
+ s.k.sockets.Remove(s)
s.k.extMu.Unlock()
}
// RecordSocket adds a socket to the system-wide socket table for tracking.
//
// Precondition: Caller must hold a reference to sock.
-func (k *Kernel) RecordSocket(sock *fs.File, family int) {
+func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Lock()
- table, ok := k.socketTable[family]
- if !ok {
- table = make(map[*refs.WeakRef]struct{})
- k.socketTable[family] = table
- }
- se := socketEntry{k: k, family: family}
- se.sock = refs.NewWeakRef(sock, &se)
- table[se.sock] = struct{}{}
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{k: k, ID: id}
+ s.Sock = refs.NewWeakRef(sock, s)
+ k.sockets.PushBack(s)
k.extMu.Unlock()
}
-// ListSockets returns a snapshot of all sockets of a given family.
-func (k *Kernel) ListSockets(family int) []*refs.WeakRef {
+// ListSockets returns a snapshot of all sockets.
+func (k *Kernel) ListSockets() []*SocketEntry {
k.extMu.Lock()
- socks := []*refs.WeakRef{}
- if table, ok := k.socketTable[family]; ok {
- socks = make([]*refs.WeakRef, 0, len(table))
- for s := range table {
- socks = append(socks, s)
- }
+ var socks []*SocketEntry
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, s)
}
k.extMu.Unlock()
return socks
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 926c4c623..dc7da529e 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -38,7 +38,11 @@ type inodeOperations struct {
fsutil.InodeNotMappable `state:"nosave"`
fsutil.InodeNotSocket `state:"nosave"`
fsutil.InodeNotSymlink `state:"nosave"`
- fsutil.InodeNotVirtual `state:"nosave"`
+
+ // Marking pipe inodes as virtual allows them to be saved and restored
+ // even if they have been unlinked. We can get away with this because
+ // their state exists entirely within the sentry.
+ fsutil.InodeVirtual `state:"nosave"`
fsutil.InodeSimpleAttributes
@@ -86,7 +90,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
switch {
case flags.Read && !flags.Write: // O_RDONLY.
- r := i.p.Open(ctx, flags)
+ r := i.p.Open(ctx, d, flags)
i.newHandleLocked(&i.rWakeup)
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
@@ -102,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
return r, nil
case flags.Write && !flags.Read: // O_WRONLY.
- w := i.p.Open(ctx, flags)
+ w := i.p.Open(ctx, d, flags)
i.newHandleLocked(&i.wWakeup)
if i.p.isNamed && !i.p.HasReaders() {
@@ -122,7 +126,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Read && flags.Write: // O_RDWR.
// Pipes opened for read-write always succeeds without blocking.
- rw := i.p.Open(ctx, flags)
+ rw := i.p.Open(ctx, d, flags)
i.newHandleLocked(&i.rWakeup)
i.newHandleLocked(&i.wWakeup)
return rw, nil
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index 31d9b0443..9a946b380 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -62,7 +62,9 @@ var perms fs.FilePermissions = fs.FilePermissions{
}
func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, doneChan chan<- struct{}) (*fs.File, error) {
- file, err := n.GetFile(ctx, nil, flags)
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
+ d := fs.NewDirent(inode, "pipe")
+ file, err := n.GetFile(ctx, d, flags)
if err != nil {
t.Fatalf("open with flags %+v failed: %v", flags, err)
}
@@ -73,7 +75,9 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag
}
func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, resChan chan<- openResult) (*fs.File, error) {
- file, err := n.GetFile(ctx, nil, flags)
+ inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
+ d := fs.NewDirent(inode, "pipe")
+ file, err := n.GetFile(ctx, d, flags)
if resChan != nil {
resChan <- openResult{file, err}
}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index b65204492..73438dc62 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -71,11 +71,6 @@ type Pipe struct {
// This value is immutable.
atomicIOBytes int64
- // The dirent backing this pipe. Shared by all readers and writers.
- //
- // This value is immutable.
- Dirent *fs.Dirent
-
// The number of active readers for this pipe.
//
// Access atomically.
@@ -130,14 +125,20 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64)
if atomicIOBytes > sizeBytes {
atomicIOBytes = sizeBytes
}
- p := &Pipe{
+ return &Pipe{
isNamed: isNamed,
max: sizeBytes,
atomicIOBytes: atomicIOBytes,
}
+}
- // Build the fs.Dirent of this pipe, shared by all fs.Files associated
- // with this pipe.
+// NewConnectedPipe initializes a pipe and returns a pair of objects
+// representing the read and write ends of the pipe.
+func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
+ p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes)
+
+ // Build an fs.Dirent for the pipe which will be shared by both
+ // returned files.
perms := fs.FilePermissions{
User: fs.PermMask{Read: true, Write: true},
}
@@ -150,36 +151,32 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64)
BlockSize: int64(atomicIOBytes),
}
ms := fs.NewPseudoMountSource()
- p.Dirent = fs.NewDirent(fs.NewInode(iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino))
- return p
-}
-
-// NewConnectedPipe initializes a pipe and returns a pair of objects
-// representing the read and write ends of the pipe.
-func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
- p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes)
- return p.Open(ctx, fs.FileFlags{Read: true}), p.Open(ctx, fs.FileFlags{Write: true})
+ d := fs.NewDirent(fs.NewInode(iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino))
+ // The p.Open calls below will each take a reference on the Dirent. We
+ // must drop the one we already have.
+ defer d.DecRef()
+ return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true})
}
// Open opens the pipe and returns a new file.
//
// Precondition: at least one of flags.Read or flags.Write must be set.
-func (p *Pipe) Open(ctx context.Context, flags fs.FileFlags) *fs.File {
+func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.File {
switch {
case flags.Read && flags.Write:
p.rOpen()
p.wOpen()
- return fs.NewFile(ctx, p.Dirent, flags, &ReaderWriter{
+ return fs.NewFile(ctx, d, flags, &ReaderWriter{
Pipe: p,
})
case flags.Read:
p.rOpen()
- return fs.NewFile(ctx, p.Dirent, flags, &Reader{
+ return fs.NewFile(ctx, d, flags, &Reader{
ReaderWriter: ReaderWriter{Pipe: p},
})
case flags.Write:
p.wOpen()
- return fs.NewFile(ctx, p.Dirent, flags, &Writer{
+ return fs.NewFile(ctx, d, flags, &Writer{
ReaderWriter: ReaderWriter{Pipe: p},
})
default:
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 4423e7efd..193447b17 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -19,6 +19,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserror"
)
@@ -92,6 +93,14 @@ const (
// ptrace(2), subsection "Ptrace access mode checking". If attach is true, it
// checks for access mode PTRACE_MODE_ATTACH; otherwise, it checks for access
// mode PTRACE_MODE_READ.
+//
+// NOTE(b/30815691): The result of CanTrace is immediately stale (e.g., a
+// racing setuid(2) may change traceability). This may pose a risk when a task
+// changes from traceable to not traceable. This is only problematic across
+// execve, where privileges may increase.
+//
+// We currently do not implement privileged executables (set-user/group-ID bits
+// and file capabilities), so that case is not reachable.
func (t *Task) CanTrace(target *Task, attach bool) bool {
// "1. If the calling thread and the target thread are in the same thread
// group, access is always allowed." - ptrace(2)
@@ -162,7 +171,13 @@ func (t *Task) CanTrace(target *Task, attach bool) bool {
if cgid := callerCreds.RealKGID; cgid != targetCreds.RealKGID || cgid != targetCreds.EffectiveKGID || cgid != targetCreds.SavedKGID {
return false
}
- // TODO(b/31916171): dumpability check
+ var targetMM *mm.MemoryManager
+ target.WithMuLocked(func(t *Task) {
+ targetMM = t.MemoryManager()
+ })
+ if targetMM != nil && targetMM.Dumpability() != mm.UserDumpable {
+ return false
+ }
if callerCreds.UserNamespace != targetCreds.UserNamespace {
return false
}
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 0572053db..27cd3728b 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -32,6 +32,51 @@ import (
// syscall.
const maxSyscallNum = 2000
+// SyscallSupportLevel is a syscall support levels.
+type SyscallSupportLevel int
+
+// String returns a human readable represetation of the support level.
+func (l SyscallSupportLevel) String() string {
+ switch l {
+ case SupportUnimplemented:
+ return "Unimplemented"
+ case SupportPartial:
+ return "Partial Support"
+ case SupportFull:
+ return "Full Support"
+ default:
+ return "Undocumented"
+ }
+}
+
+const (
+ // SupportUndocumented indicates the syscall is not documented yet.
+ SupportUndocumented = iota
+
+ // SupportUnimplemented indicates the syscall is unimplemented.
+ SupportUnimplemented
+
+ // SupportPartial indicates the syscall is partially supported.
+ SupportPartial
+
+ // SupportFull indicates the syscall is fully supported.
+ SupportFull
+)
+
+// Syscall includes the syscall implementation and compatibility information.
+type Syscall struct {
+ // Name is the syscall name.
+ Name string
+ // Fn is the implementation of the syscall.
+ Fn SyscallFn
+ // SupportLevel is the level of support implemented in gVisor.
+ SupportLevel SyscallSupportLevel
+ // Note describes the compatibility of the syscall.
+ Note string
+ // URLs is set of URLs to any relevant bugs or issues.
+ URLs []string
+}
+
// SyscallFn is a syscall implementation.
type SyscallFn func(t *Task, args arch.SyscallArguments) (uintptr, *SyscallControl, error)
@@ -83,7 +128,7 @@ type SyscallFlagsTable struct {
// Init initializes the struct, with all syscalls in table set to enable.
//
// max is the largest syscall number in table.
-func (e *SyscallFlagsTable) init(table map[uintptr]SyscallFn, max uintptr) {
+func (e *SyscallFlagsTable) init(table map[uintptr]Syscall, max uintptr) {
e.enable = make([]uint32, max+1)
for num := range table {
e.enable[num] = syscallPresent
@@ -194,7 +239,7 @@ type SyscallTable struct {
AuditNumber uint32 `state:"manual"`
// Table is the collection of functions.
- Table map[uintptr]SyscallFn `state:"manual"`
+ Table map[uintptr]Syscall `state:"manual"`
// lookup is a fixed-size array that holds the syscalls (indexed by
// their numbers). It is used for fast look ups.
@@ -247,7 +292,7 @@ func LookupSyscallTable(os abi.OS, a arch.Arch) (*SyscallTable, bool) {
func RegisterSyscallTable(s *SyscallTable) {
if s.Table == nil {
// Ensure non-nil lookup table.
- s.Table = make(map[uintptr]SyscallFn)
+ s.Table = make(map[uintptr]Syscall)
}
if s.Emulate == nil {
// Ensure non-nil emulate table.
@@ -268,8 +313,8 @@ func RegisterSyscallTable(s *SyscallTable) {
s.lookup = make([]SyscallFn, max+1)
// Initialize the fast-lookup table.
- for num, fn := range s.Table {
- s.lookup[num] = fn
+ for num, sc := range s.Table {
+ s.lookup[num] = sc.Fn
}
s.FeatureEnable.init(s.Table, max)
@@ -303,5 +348,8 @@ func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
// mapLookup is similar to Lookup, except that it only uses the syscall table,
// that is, it skips the fast look array. This is available for benchmarking.
func (s *SyscallTable) mapLookup(sysno uintptr) SyscallFn {
- return s.Table[sysno]
+ if sc, ok := s.Table[sysno]; ok {
+ return sc.Fn
+ }
+ return nil
}
diff --git a/pkg/sentry/kernel/table_test.go b/pkg/sentry/kernel/table_test.go
index 8f7cdb9f3..3f2b042c8 100644
--- a/pkg/sentry/kernel/table_test.go
+++ b/pkg/sentry/kernel/table_test.go
@@ -26,11 +26,13 @@ const (
)
func createSyscallTable() *SyscallTable {
- m := make(map[uintptr]SyscallFn)
+ m := make(map[uintptr]Syscall)
for i := uintptr(0); i <= maxTestSyscall; i++ {
j := i
- m[i] = func(*Task, arch.SyscallArguments) (uintptr, *SyscallControl, error) {
- return j, nil, nil
+ m[i] = Syscall{
+ Fn: func(*Task, arch.SyscallArguments) (uintptr, *SyscallControl, error) {
+ return j, nil, nil
+ },
}
}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index f9378c2de..4d889422f 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -455,12 +455,13 @@ type Task struct {
// single numa node, all policies are no-ops. We only track this information
// so that we can return reasonable values if the application calls
// get_mempolicy(2) after setting a non-default policy. Note that in the
- // real syscall, nodemask can be longer than 4 bytes, but we always report a
- // single node so never need to save more than a single bit.
+ // real syscall, nodemask can be longer than a single unsigned long, but we
+ // always report a single node so never need to save more than a single
+ // bit.
//
// numaPolicy and numaNodeMask are protected by mu.
numaPolicy int32
- numaNodeMask uint32
+ numaNodeMask uint64
// If netns is true, the task is in a non-root network namespace. Network
// namespaces aren't currently implemented in full; being in a network
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 5d1425d5c..35d5cb90c 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -68,6 +68,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
"gvisor.googlesource.com/gvisor/pkg/syserror"
)
@@ -198,6 +199,12 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
return flags.CloseOnExec
})
+ // NOTE(b/30815691): We currently do not implement privileged
+ // executables (set-user/group-ID bits and file capabilities). This
+ // allows us to unconditionally enable user dumpability on the new mm.
+ // See fs/exec.c:setup_new_exec.
+ r.tc.MemoryManager.SetDumpability(mm.UserDumpable)
+
// Switch to the new process.
t.MemoryManager().Deactivate()
t.mu.Lock()
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index 17f08729a..ec95f78d0 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
"gvisor.googlesource.com/gvisor/pkg/syserror"
)
@@ -206,8 +207,17 @@ func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) {
// (filesystem UIDs aren't implemented, nor are any of the capabilities in
// question)
- // Not documented, but compare Linux's kernel/cred.c:commit_creds().
if oldE != newE {
+ // "[dumpability] is reset to the current value contained in
+ // the file /proc/sys/fs/suid_dumpable (which by default has
+ // the value 0), in the following circumstances: The process's
+ // effective user or group ID is changed." - prctl(2)
+ //
+ // (suid_dumpable isn't implemented, so we just use the
+ // default.
+ t.MemoryManager().SetDumpability(mm.NotDumpable)
+
+ // Not documented, but compare Linux's kernel/cred.c:commit_creds().
t.parentDeathSignal = 0
}
}
@@ -303,8 +313,18 @@ func (t *Task) setKGIDsUncheckedLocked(newR, newE, newS auth.KGID) {
t.creds = t.creds.Fork() // See doc for creds.
t.creds.RealKGID, t.creds.EffectiveKGID, t.creds.SavedKGID = newR, newE, newS
- // Not documented, but compare Linux's kernel/cred.c:commit_creds().
if oldE != newE {
+ // "[dumpability] is reset to the current value contained in
+ // the file /proc/sys/fs/suid_dumpable (which by default has
+ // the value 0), in the following circumstances: The process's
+ // effective user or group ID is changed." - prctl(2)
+ //
+ // (suid_dumpable isn't implemented, so we just use the
+ // default.
+ t.MemoryManager().SetDumpability(mm.NotDumpable)
+
+ // Not documented, but compare Linux's
+ // kernel/cred.c:commit_creds().
t.parentDeathSignal = 0
}
}
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index 5455f6ea9..1c94ab11b 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -622,14 +622,14 @@ func (t *Task) SetNiceness(n int) {
}
// NumaPolicy returns t's current numa policy.
-func (t *Task) NumaPolicy() (policy int32, nodeMask uint32) {
+func (t *Task) NumaPolicy() (policy int32, nodeMask uint64) {
t.mu.Lock()
defer t.mu.Unlock()
return t.numaPolicy, t.numaNodeMask
}
// SetNumaPolicy sets t's numa policy.
-func (t *Task) SetNumaPolicy(policy int32, nodeMask uint32) {
+func (t *Task) SetNumaPolicy(policy int32, nodeMask uint64) {
t.mu.Lock()
defer t.mu.Unlock()
t.numaPolicy = policy
diff --git a/pkg/sentry/memutil/BUILD b/pkg/sentry/memutil/BUILD
deleted file mode 100644
index 68b03d4cc..000000000
--- a/pkg/sentry/memutil/BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "memutil",
- srcs = [
- "memutil.go",
- "memutil_unsafe.go",
- ],
- importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/memutil",
- visibility = ["//pkg/sentry:internal"],
- deps = ["@org_golang_x_sys//unix:go_default_library"],
-)
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 7a65a62a2..7646d5ab2 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -37,6 +37,7 @@ func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *Memo
privateRefs: &privateRefs{},
users: 1,
auxv: arch.Auxv{},
+ dumpability: UserDumpable,
aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
}
}
@@ -79,8 +80,9 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
envv: mm.envv,
auxv: append(arch.Auxv(nil), mm.auxv...),
// IncRef'd below, once we know that there isn't an error.
- executable: mm.executable,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
}
// Copy vmas.
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index 9768e51f1..c218006ee 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -20,6 +20,36 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
)
+// Dumpability describes if and how core dumps should be created.
+type Dumpability int
+
+const (
+ // NotDumpable indicates that core dumps should never be created.
+ NotDumpable Dumpability = iota
+
+ // UserDumpable indicates that core dumps should be created, owned by
+ // the current user.
+ UserDumpable
+
+ // RootDumpable indicates that core dumps should be created, owned by
+ // root.
+ RootDumpable
+)
+
+// Dumpability returns the dumpability.
+func (mm *MemoryManager) Dumpability() Dumpability {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.dumpability
+}
+
+// SetDumpability sets the dumpability.
+func (mm *MemoryManager) SetDumpability(d Dumpability) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.dumpability = d
+}
+
// ArgvStart returns the start of the application argument vector.
//
// There is no guarantee that this value is sensible w.r.t. ArgvEnd.
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index eb6defa2b..604866d04 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -219,6 +219,12 @@ type MemoryManager struct {
// executable is protected by metadataMu.
executable *fs.Dirent
+ // dumpability describes if and how this MemoryManager may be dumped to
+ // userspace.
+ //
+ // dumpability is protected by metadataMu.
+ dumpability Dumpability
+
// aioManager keeps track of AIOContexts used for async IOs. AIOManager
// must be cloned when CLONE_VM is used.
aioManager aioManager
@@ -270,6 +276,12 @@ type vma struct {
mlockMode memmap.MLockMode
+ // numaPolicy is the NUMA policy for this vma set by mbind().
+ numaPolicy int32
+
+ // numaNodemask is the NUMA nodemask for this vma set by mbind().
+ numaNodemask uint64
+
// If id is not nil, it controls the lifecycle of mappable and provides vma
// metadata shown in /proc/[pid]/maps, and the vma holds a reference.
id memmap.MappingIdentity
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 0368c6794..9cf136532 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -470,6 +470,16 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr usermem.Addr, oldSi
return 0, syserror.EINVAL
}
+ // Check that the new region is valid.
+ _, err := mm.findAvailableLocked(newSize, findAvailableOpts{
+ Addr: newAddr,
+ Fixed: true,
+ Unmap: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
// Unmap any mappings at the destination.
mm.unmapLocked(ctx, newAR)
@@ -963,6 +973,59 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
return nil
}
+// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR).
+func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ vseg := mm.vmas.FindSegment(addr)
+ if !vseg.Ok() {
+ return 0, 0, syserror.EFAULT
+ }
+ vma := vseg.ValuePtr()
+ return vma.numaPolicy, vma.numaNodemask, nil
+}
+
+// SetNumaPolicy implements the semantics of Linux's mbind().
+func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy int32, nodemask uint64) error {
+ if !addr.IsPageAligned() {
+ return syserror.EINVAL
+ }
+ // Linux allows this to overflow.
+ la, _ := usermem.Addr(length).RoundUp()
+ ar, ok := addr.ToRange(uint64(la))
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ar.Length() == 0 {
+ return nil
+ }
+
+ mm.mappingMu.Lock()
+ defer mm.mappingMu.Unlock()
+ defer func() {
+ mm.vmas.MergeRange(ar)
+ mm.vmas.MergeAdjacent(ar)
+ }()
+ vseg := mm.vmas.LowerBoundSegment(ar.Start)
+ lastEnd := ar.Start
+ for {
+ if !vseg.Ok() || lastEnd < vseg.Start() {
+ // "EFAULT: ... there was an unmapped hole in the specified memory
+ // range specified [sic] by addr and len." - mbind(2)
+ return syserror.EFAULT
+ }
+ vseg = mm.vmas.Isolate(vseg, ar)
+ vma := vseg.ValuePtr()
+ vma.numaPolicy = policy
+ vma.numaNodemask = nodemask
+ lastEnd = vseg.End()
+ if ar.End <= lastEnd {
+ return nil
+ }
+ vseg, _ = vseg.NextNonEmpty()
+ }
+}
+
// Decommit implements the semantics of Linux's madvise(MADV_DONTNEED).
func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error {
ar, ok := addr.ToRange(length)
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index 02203f79f..0af8de5b0 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -107,6 +107,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
private: opts.Private,
growsDown: opts.GrowsDown,
mlockMode: opts.MLockMode,
+ numaPolicy: linux.MPOL_DEFAULT,
id: opts.MappingIdentity,
hint: opts.Hint,
}
@@ -436,6 +437,8 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa
vma1.private != vma2.private ||
vma1.growsDown != vma2.growsDown ||
vma1.mlockMode != vma2.mlockMode ||
+ vma1.numaPolicy != vma2.numaPolicy ||
+ vma1.numaNodemask != vma2.numaNodemask ||
vma1.id != vma2.id ||
vma1.hint != vma2.hint {
return vma{}, false
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 8a8a0e4e4..ca2d5ba6f 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -63,9 +63,10 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/log",
+ "//pkg/memutil",
"//pkg/sentry/arch",
"//pkg/sentry/context",
- "//pkg/sentry/memutil",
+ "//pkg/sentry/hostmm",
"//pkg/sentry/platform",
"//pkg/sentry/safemem",
"//pkg/sentry/usage",
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 2b9924ad7..6d91f1a7b 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -32,6 +32,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/hostmm"
"gvisor.googlesource.com/gvisor/pkg/sentry/platform"
"gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
"gvisor.googlesource.com/gvisor/pkg/sentry/usage"
@@ -162,6 +163,11 @@ type MemoryFile struct {
// evictionWG counts the number of goroutines currently performing evictions.
evictionWG sync.WaitGroup
+
+ // stopNotifyPressure stops memory cgroup pressure level
+ // notifications used to drive eviction. stopNotifyPressure is
+ // immutable.
+ stopNotifyPressure func()
}
// MemoryFileOpts provides options to NewMemoryFile.
@@ -169,6 +175,11 @@ type MemoryFileOpts struct {
// DelayedEviction controls the extent to which the MemoryFile may delay
// eviction of evictable allocations.
DelayedEviction DelayedEvictionType
+
+ // If UseHostMemcgPressure is true, use host memory cgroup pressure level
+ // notifications to determine when eviction is necessary. This option has
+ // no effect unless DelayedEviction is DelayedEvictionEnabled.
+ UseHostMemcgPressure bool
}
// DelayedEvictionType is the type of MemoryFileOpts.DelayedEviction.
@@ -186,9 +197,14 @@ const (
// evictable allocations until doing so is considered necessary to avoid
// performance degradation due to host memory pressure, or OOM kills.
//
- // As of this writing, DelayedEvictionEnabled delays evictions until the
- // reclaimer goroutine is out of work (pages to reclaim), then evicts all
- // pending evictable allocations immediately.
+ // As of this writing, the behavior of DelayedEvictionEnabled depends on
+ // whether or not MemoryFileOpts.UseHostMemcgPressure is enabled:
+ //
+ // - If UseHostMemcgPressure is true, evictions are delayed until memory
+ // pressure is indicated.
+ //
+ // - Otherwise, evictions are only delayed until the reclaimer goroutine
+ // is out of work (pages to reclaim).
DelayedEvictionEnabled
// DelayedEvictionManual requires that evictable allocations are only
@@ -292,6 +308,22 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
}
f.mappings.Store(make([]uintptr, initialSize/chunkSize))
f.reclaimCond.L = &f.mu
+
+ if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure {
+ stop, err := hostmm.NotifyCurrentMemcgPressureCallback(func() {
+ f.mu.Lock()
+ startedAny := f.startEvictionsLocked()
+ f.mu.Unlock()
+ if startedAny {
+ log.Debugf("pgalloc.MemoryFile performing evictions due to memcg pressure")
+ }
+ }, "low")
+ if err != nil {
+ return nil, fmt.Errorf("failed to configure memcg pressure level notifications: %v", err)
+ }
+ f.stopNotifyPressure = stop
+ }
+
go f.runReclaim() // S/R-SAFE: f.mu
// The Linux kernel contains an optional feature called "Integrity
@@ -692,9 +724,11 @@ func (f *MemoryFile) MarkEvictable(user EvictableMemoryUser, er EvictableRange)
// Kick off eviction immediately.
f.startEvictionGoroutineLocked(user, info)
case DelayedEvictionEnabled:
- // Ensure that the reclaimer goroutine is running, so that it can
- // start eviction when necessary.
- f.reclaimCond.Signal()
+ if !f.opts.UseHostMemcgPressure {
+ // Ensure that the reclaimer goroutine is running, so that it
+ // can start eviction when necessary.
+ f.reclaimCond.Signal()
+ }
}
}
}
@@ -992,11 +1026,12 @@ func (f *MemoryFile) runReclaim() {
}
f.markReclaimed(fr)
}
+
// We only get here if findReclaimable finds f.destroyed set and returns
// false.
f.mu.Lock()
- defer f.mu.Unlock()
if !f.destroyed {
+ f.mu.Unlock()
panic("findReclaimable broke out of reclaim loop, but destroyed is no longer set")
}
f.file.Close()
@@ -1016,6 +1051,13 @@ func (f *MemoryFile) runReclaim() {
}
// Similarly, invalidate f.mappings. (atomic.Value.Store(nil) panics.)
f.mappings.Store([]uintptr{})
+ f.mu.Unlock()
+
+ // This must be called without holding f.mu to avoid circular lock
+ // ordering.
+ if f.stopNotifyPressure != nil {
+ f.stopNotifyPressure()
+ }
}
func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
@@ -1029,7 +1071,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
if f.reclaimable {
break
}
- if f.opts.DelayedEviction == DelayedEvictionEnabled {
+ if f.opts.DelayedEviction == DelayedEvictionEnabled && !f.opts.UseHostMemcgPressure {
// No work to do. Evict any pending evictable allocations to
// get more reclaimable pages before going to sleep.
f.startEvictionsLocked()
@@ -1089,14 +1131,17 @@ func (f *MemoryFile) StartEvictions() {
}
// Preconditions: f.mu must be locked.
-func (f *MemoryFile) startEvictionsLocked() {
+func (f *MemoryFile) startEvictionsLocked() bool {
+ startedAny := false
for user, info := range f.evictable {
// Don't start multiple goroutines to evict the same user's
// allocations.
if !info.evicting {
f.startEvictionGoroutineLocked(user, info)
+ startedAny = true
}
}
+ return startedAny
}
// Preconditions: info == f.evictable[user]. !info.evicting. f.mu must be
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 9999e58f4..2931d6ddc 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -32,10 +32,10 @@ go_library(
"//pkg/atomicbitops",
"//pkg/cpuid",
"//pkg/log",
+ "//pkg/procid",
"//pkg/sentry/arch",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
- "//pkg/sentry/platform/procid",
"//pkg/sentry/platform/ring0",
"//pkg/sentry/platform/ring0/pagetables",
"//pkg/sentry/platform/safecopy",
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index f5953b96e..f8ccd86af 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -23,7 +23,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/atomicbitops"
"gvisor.googlesource.com/gvisor/pkg/log"
- "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid"
+ "gvisor.googlesource.com/gvisor/pkg/procid"
"gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0"
"gvisor.googlesource.com/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index e9e4a0d16..434d003a3 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -20,11 +20,11 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/log",
+ "//pkg/procid",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
- "//pkg/sentry/platform/procid",
"//pkg/sentry/platform/safecopy",
"//pkg/sentry/usermem",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 83b43057f..d7800a55e 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -21,9 +21,10 @@ import (
"sync"
"syscall"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/procid"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
"gvisor.googlesource.com/gvisor/pkg/sentry/platform"
- "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
)
@@ -300,6 +301,18 @@ const (
killed
)
+func (t *thread) dumpAndPanic(message string) {
+ var regs syscall.PtraceRegs
+ message += "\n"
+ if err := t.getRegs(&regs); err == nil {
+ message += dumpRegs(&regs)
+ } else {
+ log.Warningf("unable to get registers: %v", err)
+ }
+ message += fmt.Sprintf("stubStart\t = %016x\n", stubStart)
+ panic(message)
+}
+
// wait waits for a stop event.
//
// Precondition: outcome is a valid waitOutcome.
@@ -320,7 +333,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
switch outcome {
case stopped:
if !status.Stopped() {
- panic(fmt.Sprintf("ptrace status unexpected: got %v, wanted stopped", status))
+ t.dumpAndPanic(fmt.Sprintf("ptrace status unexpected: got %v, wanted stopped", status))
}
stopSig := status.StopSignal()
if stopSig == 0 {
@@ -334,12 +347,12 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
return stopSig
case killed:
if !status.Exited() && !status.Signaled() {
- panic(fmt.Sprintf("ptrace status unexpected: got %v, wanted exited", status))
+ t.dumpAndPanic(fmt.Sprintf("ptrace status unexpected: got %v, wanted exited", status))
}
return syscall.Signal(status.ExitStatus())
default:
// Should not happen.
- panic(fmt.Sprintf("unknown outcome: %v", outcome))
+ t.dumpAndPanic(fmt.Sprintf("unknown outcome: %v", outcome))
}
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 77a0e908f..fdd21c8f8 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -17,6 +17,8 @@
package ptrace
import (
+ "fmt"
+ "strings"
"syscall"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
@@ -102,3 +104,38 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
}
return uintptr(rval), nil
}
+
+func dumpRegs(regs *syscall.PtraceRegs) string {
+ var m strings.Builder
+
+ fmt.Fprintf(&m, "Registers:\n")
+ fmt.Fprintf(&m, "\tR15\t = %016x\n", regs.R15)
+ fmt.Fprintf(&m, "\tR14\t = %016x\n", regs.R14)
+ fmt.Fprintf(&m, "\tR13\t = %016x\n", regs.R13)
+ fmt.Fprintf(&m, "\tR12\t = %016x\n", regs.R12)
+ fmt.Fprintf(&m, "\tRbp\t = %016x\n", regs.Rbp)
+ fmt.Fprintf(&m, "\tRbx\t = %016x\n", regs.Rbx)
+ fmt.Fprintf(&m, "\tR11\t = %016x\n", regs.R11)
+ fmt.Fprintf(&m, "\tR10\t = %016x\n", regs.R10)
+ fmt.Fprintf(&m, "\tR9\t = %016x\n", regs.R9)
+ fmt.Fprintf(&m, "\tR8\t = %016x\n", regs.R8)
+ fmt.Fprintf(&m, "\tRax\t = %016x\n", regs.Rax)
+ fmt.Fprintf(&m, "\tRcx\t = %016x\n", regs.Rcx)
+ fmt.Fprintf(&m, "\tRdx\t = %016x\n", regs.Rdx)
+ fmt.Fprintf(&m, "\tRsi\t = %016x\n", regs.Rsi)
+ fmt.Fprintf(&m, "\tRdi\t = %016x\n", regs.Rdi)
+ fmt.Fprintf(&m, "\tOrig_rax = %016x\n", regs.Orig_rax)
+ fmt.Fprintf(&m, "\tRip\t = %016x\n", regs.Rip)
+ fmt.Fprintf(&m, "\tCs\t = %016x\n", regs.Cs)
+ fmt.Fprintf(&m, "\tEflags\t = %016x\n", regs.Eflags)
+ fmt.Fprintf(&m, "\tRsp\t = %016x\n", regs.Rsp)
+ fmt.Fprintf(&m, "\tSs\t = %016x\n", regs.Ss)
+ fmt.Fprintf(&m, "\tFs_base\t = %016x\n", regs.Fs_base)
+ fmt.Fprintf(&m, "\tGs_base\t = %016x\n", regs.Gs_base)
+ fmt.Fprintf(&m, "\tDs\t = %016x\n", regs.Ds)
+ fmt.Fprintf(&m, "\tEs\t = %016x\n", regs.Es)
+ fmt.Fprintf(&m, "\tFs\t = %016x\n", regs.Fs)
+ fmt.Fprintf(&m, "\tGs\t = %016x\n", regs.Gs)
+
+ return m.String()
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 2c07b4ac3..914be7486 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -22,9 +22,9 @@ import (
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/procid"
"gvisor.googlesource.com/gvisor/pkg/seccomp"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
- "gvisor.googlesource.com/gvisor/pkg/sentry/platform/procid"
)
const syscallEvent syscall.Signal = 0x80
@@ -142,7 +142,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// down available calls only to what is needed.
rules := []seccomp.RuleSet{
// Rules for trapping vsyscall access.
- seccomp.RuleSet{
+ {
Rules: seccomp.SyscallRules{
syscall.SYS_GETTIMEOFDAY: {},
syscall.SYS_TIME: {},
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index c0238691d..434d7ca2e 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -406,12 +406,20 @@ func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
return nil
}
if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) {
- tcred := t.Credentials()
- return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
+ return MakeCreds(t)
}
return nil
}
+// MakeCreds creates default SCMCredentials.
+func MakeCreds(t *kernel.Task) SCMCredentials {
+ if t == nil {
+ return nil
+ }
+ tcred := t.Credentials()
+ return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
+}
+
// New creates default control messages if needed.
func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages {
return transport.ControlMessages{
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
index 44bb97b5b..7e2679ea0 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -32,7 +32,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
- "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index de4b963da..a50798cb3 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -44,7 +44,6 @@ import (
ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -52,6 +51,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -227,7 +227,8 @@ type SocketOperations struct {
family int
Endpoint tcpip.Endpoint
- skType transport.SockType
+ skType linux.SockType
+ protocol int
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
@@ -252,8 +253,8 @@ type SocketOperations struct {
}
// New creates a new endpoint socket.
-func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
- if skType == transport.SockStream {
+func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Qu
family: family,
Endpoint: endpoint,
skType: skType,
+ protocol: protocol,
}), nil
}
@@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns, err := New(t, s.family, s.skType, wq, ep)
+ ns, err := New(t, s.family, s.skType, s.protocol, wq, ep)
if err != nil {
return 0, nil, 0, err
}
@@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
- t.Kernel().RecordSocket(ns, s.family)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, syserr.FromError(e)
}
@@ -637,7 +639,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -663,7 +665,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_TYPE:
@@ -918,6 +920,30 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.TCP_CONGESTION:
+ if outLen <= 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.CongestionControlOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // We match linux behaviour here where it returns the lower of
+ // TCP_CA_NAME_MAX bytes or the value of the option length.
+ //
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const tcpCANameMax = 16
+
+ toCopy := tcpCANameMax
+ if outLen < tcpCANameMax {
+ toCopy = outLen
+ }
+ b := make([]byte, toCopy)
+ copy(b, v)
+ return b, nil
+
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1220,6 +1246,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_CONGESTION:
+ v := tcpip.CongestionControlOption(optVal)
+ if err := ep.SetSockOpt(v); err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ return nil
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -2281,3 +2313,51 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
}
return rv
}
+
+// State implements socket.Socket.State. State translates the internal state
+// returned by netstack to values defined by Linux.
+func (s *SocketOperations) State() uint32 {
+ if s.family != linux.AF_INET && s.family != linux.AF_INET6 {
+ // States not implemented for this socket's family.
+ return 0
+ }
+
+ if !s.isPacketBased() {
+ // TCP socket.
+ switch tcp.EndpointState(s.Endpoint.State()) {
+ case tcp.StateEstablished:
+ return linux.TCP_ESTABLISHED
+ case tcp.StateSynSent:
+ return linux.TCP_SYN_SENT
+ case tcp.StateSynRecv:
+ return linux.TCP_SYN_RECV
+ case tcp.StateFinWait1:
+ return linux.TCP_FIN_WAIT1
+ case tcp.StateFinWait2:
+ return linux.TCP_FIN_WAIT2
+ case tcp.StateTimeWait:
+ return linux.TCP_TIME_WAIT
+ case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError:
+ return linux.TCP_CLOSE
+ case tcp.StateCloseWait:
+ return linux.TCP_CLOSE_WAIT
+ case tcp.StateLastAck:
+ return linux.TCP_LAST_ACK
+ case tcp.StateListen:
+ return linux.TCP_LISTEN
+ case tcp.StateClosing:
+ return linux.TCP_CLOSING
+ default:
+ // Internal or unknown state.
+ return 0
+ }
+ }
+
+ // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets.
+ return 0
+}
+
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.skType, s.protocol
+}
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index ec930d8d5..516582828 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -23,7 +23,6 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
@@ -42,7 +41,7 @@ type provider struct {
// getTransportProtocol figures out transport protocol. Currently only TCP,
// UDP, and ICMP are supported.
-func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
switch stype {
case linux.SOCK_STREAM:
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
@@ -80,7 +79,7 @@ func getTransportProtocol(ctx context.Context, stype transport.SockType, protoco
}
// Socket creates a new socket object for the AF_INET or AF_INET6 family.
-func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Fail right away if we don't have a stack.
stack := t.NetworkContext()
if stack == nil {
@@ -112,11 +111,11 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int
return nil, syserr.TranslateNetstackError(e)
}
- return New(t, p.family, stype, wq, ep)
+ return New(t, p.family, stype, protocol, wq, ep)
}
// Pair just returns nil sockets (not supported).
-func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a469af7ac..975f47bc3 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -30,7 +30,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
- "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
"//pkg/syserr",
"//pkg/syserror",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 41f9693bb..c62c8d8f1 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -19,7 +19,9 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
"gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
@@ -28,7 +30,6 @@ import (
ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/syserror"
@@ -55,15 +56,22 @@ type socketOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
- family int // Read-only.
- fd int // must be O_NONBLOCK
- queue waiter.Queue
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+ fd int // must be O_NONBLOCK
+ queue waiter.Queue
}
var _ = socket.Socket(&socketOperations{})
-func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
- s := &socketOperations{family: family, fd: fd}
+func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ s := &socketOperations{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ }
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
}
@@ -221,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
}
- f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
if err != nil {
syscall.Close(fd)
return 0, nil, 0, err
@@ -232,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
}
kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits())
- t.Kernel().RecordSocket(f, s.family)
+ t.Kernel().RecordSocket(f)
return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}
@@ -519,12 +527,39 @@ func translateIOSyscallError(err error) error {
return err
}
+// State implements socket.Socket.State.
+func (s *socketOperations) State() uint32 {
+ info := linux.TCPInfo{}
+ buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo)
+ if err != nil {
+ if err != syscall.ENOPROTOOPT {
+ log.Warningf("Failed to get TCP socket info from %+v: %v", s, err)
+ }
+ // For non-TCP sockets, silently ignore the failure.
+ return 0
+ }
+ if len(buf) != linux.SizeOfTCPInfo {
+ // Unmarshal below will panic if getsockopt returns a buffer of
+ // unexpected size.
+ log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo)
+ return 0
+ }
+
+ binary.Unmarshal(buf, usermem.ByteOrder, &info)
+ return uint32(info.State)
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
type socketProvider struct {
family int
}
// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Check that we are using the host network stack.
stack := t.NetworkContext()
if stack == nil {
@@ -535,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
}
// Only accept TCP and UDP.
- stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ stype := stypeflags & linux.SOCK_TYPE_MASK
switch stype {
case syscall.SOCK_STREAM:
switch protocol {
@@ -558,15 +593,15 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
// Conservatively ignore all flags specified by the application and add
// SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
// to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
- fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
return nil, syserr.FromError(err)
}
- return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+ return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 76cf12fd4..5dc103877 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -22,7 +22,6 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
)
@@ -66,10 +65,10 @@ type socketProvider struct {
}
// Socket implements socket.Provider.Socket.
-func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Netlink sockets must be specified as datagram or raw, but they
// behave the same regardless of type.
- if stype != transport.SockDgram && stype != transport.SockRaw {
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
return nil, syserr.ErrSocketNotSupported
}
@@ -83,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol
return nil, err
}
- s, err := NewSocket(t, p)
+ s, err := NewSocket(t, stype, p)
if err != nil {
return nil, err
}
@@ -94,7 +93,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol
}
// Pair implements socket.Provider.Pair by returning an error.
-func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
// Netlink sockets never supports creating socket pairs.
return nil, nil, syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index afd06ca33..62659784a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -80,6 +80,10 @@ type Socket struct {
// protocol is the netlink protocol implementation.
protocol Protocol
+ // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
+ // netlink sockets.
+ skType linux.SockType
+
// ep is a datagram unix endpoint used to buffer messages sent from the
// kernel to userspace. RecvMsg reads messages from this endpoint.
ep transport.Endpoint
@@ -105,7 +109,7 @@ type Socket struct {
var _ socket.Socket = (*Socket)(nil)
// NewSocket creates a new Socket.
-func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
+func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
// Datagram endpoint used to buffer kernel -> user messages.
ep := transport.NewConnectionless()
@@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
return &Socket{
ports: t.Kernel().NetlinkPorts(),
protocol: protocol,
+ skType: skType,
ep: ep,
connection: connection,
sendBufferSize: defaultSendBufferSize,
@@ -616,3 +621,13 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence,
n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
return int64(n), err.ToError()
}
+
+// State implements socket.Socket.State.
+func (s *Socket) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
+ return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
+}
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 4da14a1e0..33ba20de7 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/sentry/socket/hostinet",
"//pkg/sentry/socket/rpcinet/conn",
"//pkg/sentry/socket/rpcinet/notifier",
- "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 55e0b6665..c22ff1ff0 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -32,7 +32,6 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier"
pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -54,7 +53,10 @@ type socketOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
- family int // Read-only.
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+
fd uint32 // must be O_NONBLOCK
wq *waiter.Queue
rpcConn *conn.RPCConnection
@@ -70,7 +72,7 @@ type socketOperations struct {
var _ = socket.Socket(&socketOperations{})
// New creates a new RPC socket.
-func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) {
+func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) {
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */)
<-c
@@ -87,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, pr
defer dirent.DecRef()
return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
family: family,
+ stype: skType,
+ protocol: protocol,
wq: &wq,
fd: fd,
rpcConn: stack.rpcConn,
@@ -333,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
if err != nil {
return 0, nil, 0, syserr.FromError(err)
}
- t.Kernel().RecordSocket(file, s.family)
+ t.Kernel().RecordSocket(file)
if peerRequested {
return fd, payload.Address.Address, payload.Address.Length, nil
@@ -830,12 +834,23 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
}
+// State implements socket.Socket.State.
+func (s *socketOperations) State() uint32 {
+ // TODO(b/127845868): Define a new rpc to query the socket state.
+ return 0
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
type socketProvider struct {
family int
}
// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Check that we are using the RPC network stack.
stack := t.NetworkContext()
if stack == nil {
@@ -851,7 +866,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
//
// Try to restrict the flags we will accept to minimize backwards
// incompatibility with netstack.
- stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ stype := stypeflags & linux.SOCK_TYPE_MASK
switch stype {
case syscall.SOCK_STREAM:
switch protocol {
@@ -871,11 +886,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
return nil, nil
}
- return newSocketFile(t, s, p.family, stype, 0)
+ return newSocketFile(t, s, p.family, stype, protocol)
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 9393acd28..d60944b6b 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -116,6 +116,13 @@ type Socket interface {
// SendTimeout gets the current timeout (in ns) for send operations. Zero
// means no timeout, and negative means DONTWAIT.
SendTimeout() int64
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs. The returned state value is protocol-specific.
+ State() uint32
+
+ // Type returns the family, socket type and protocol of the socket.
+ Type() (family int, skType linux.SockType, protocol int)
}
// Provider is the interface implemented by providers of sockets for specific
@@ -126,12 +133,12 @@ type Provider interface {
// If a nil Socket _and_ a nil error is returned, it means that the
// protocol is not supported. A non-nil error should only be returned
// if the protocol is supported, but an error occurs during creation.
- Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error)
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error)
// Pair creates a pair of connected sockets.
//
// See Socket for error information.
- Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
}
// families holds a map of all known address families and their providers.
@@ -145,14 +152,14 @@ func RegisterProvider(family int, provider Provider) {
}
// New creates a new socket with the given family, type and protocol.
-func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
for _, p := range families[family] {
s, err := p.Socket(t, stype, protocol)
if err != nil {
return nil, err
}
if s != nil {
- t.Kernel().RecordSocket(s, family)
+ t.Kernel().RecordSocket(s)
return s, nil
}
}
@@ -162,7 +169,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f
// Pair creates a new connected socket pair with the given family, type and
// protocol.
-func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
providers, ok := families[family]
if !ok {
return nil, nil, syserr.ErrAddressFamilyNotSupported
@@ -175,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*
}
if s1 != nil && s2 != nil {
k := t.Kernel()
- k.RecordSocket(s1, family)
- k.RecordSocket(s2, family)
+ k.RecordSocket(s1)
+ k.RecordSocket(s2)
return s1, s2, nil
}
}
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 5a2de0c4c..52f324eed 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -28,6 +28,7 @@ go_library(
importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
"//pkg/ilist",
"//pkg/refs",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 18e492862..db79ac904 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -17,6 +17,7 @@ package transport
import (
"sync"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/waiter"
@@ -44,7 +45,7 @@ type ConnectingEndpoint interface {
// Type returns the socket type, typically either SockStream or
// SockSeqpacket. The connection attempt must be aborted if this
// value doesn't match the ConnectableEndpoint's type.
- Type() SockType
+ Type() linux.SockType
// GetLocalAddress returns the bound path.
GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
@@ -100,7 +101,7 @@ type connectionedEndpoint struct {
// stype is used by connecting sockets to ensure that they are the
// same type. The value is typically either tcpip.SockSeqpacket or
// tcpip.SockStream.
- stype SockType
+ stype linux.SockType
// acceptedChan is per the TCP endpoint implementation. Note that the
// sockets in this channel are _already in the connected state_, and
@@ -111,7 +112,7 @@ type connectionedEndpoint struct {
}
// NewConnectioned creates a new unbound connectionedEndpoint.
-func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
+func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint {
return &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
@@ -121,7 +122,7 @@ func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
-func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
a := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
@@ -138,7 +139,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
- if stype == SockStream {
+ if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
} else {
@@ -162,7 +163,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
-func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
return &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
@@ -177,7 +178,7 @@ func (e *connectionedEndpoint) ID() uint64 {
}
// Type implements ConnectingEndpoint.Type and Endpoint.Type.
-func (e *connectionedEndpoint) Type() SockType {
+func (e *connectionedEndpoint) Type() linux.SockType {
return e.stype
}
@@ -293,7 +294,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
- if e.stype == SockStream {
+ if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
ne.receiver = &queueReceiver{readQueue: writeQueue}
@@ -308,7 +309,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
writeQueue: writeQueue,
}
readQueue.IncRef()
- if e.stype == SockStream {
+ if e.stype == linux.SOCK_STREAM {
returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
} else {
returnConnect(&queueReceiver{readQueue: readQueue}, connected)
@@ -428,7 +429,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser
func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
// Stream sockets do not support specifying the endpoint. Seqpacket
// sockets ignore the passed endpoint.
- if e.stype == SockStream && to != nil {
+ if e.stype == linux.SOCK_STREAM && to != nil {
return 0, syserr.ErrNotSupported
}
return e.baseEndpoint.SendMsg(data, c, to)
@@ -458,3 +459,11 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return ready
}
+
+// State implements socket.Socket.State.
+func (e *connectionedEndpoint) State() uint32 {
+ if e.Connected() {
+ return linux.SS_CONNECTED
+ }
+ return linux.SS_UNCONNECTED
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 43ff875e4..81ebfba10 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -15,6 +15,7 @@
package transport
import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/waiter"
@@ -118,8 +119,8 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo
}
// Type implements Endpoint.Type.
-func (e *connectionlessEndpoint) Type() SockType {
- return SockDgram
+func (e *connectionlessEndpoint) Type() linux.SockType {
+ return linux.SOCK_DGRAM
}
// Connect attempts to connect directly to server.
@@ -194,3 +195,18 @@ func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMa
return ready
}
+
+// State implements socket.Socket.State.
+func (e *connectionlessEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
+ switch {
+ case e.isBound():
+ return linux.SS_UNCONNECTED
+ case e.Connected():
+ return linux.SS_CONNECTING
+ default:
+ return linux.SS_DISCONNECTING
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index b734b4c20..5c55c529e 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -19,6 +19,7 @@ import (
"sync"
"sync/atomic"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
@@ -28,21 +29,6 @@ import (
// initialLimit is the starting limit for the socket buffers.
const initialLimit = 16 * 1024
-// A SockType is a type (as opposed to family) of sockets. These are enumerated
-// in the syscall package as syscall.SOCK_* constants.
-type SockType int
-
-const (
- // SockStream corresponds to syscall.SOCK_STREAM.
- SockStream SockType = 1
- // SockDgram corresponds to syscall.SOCK_DGRAM.
- SockDgram SockType = 2
- // SockRaw corresponds to syscall.SOCK_RAW.
- SockRaw SockType = 3
- // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET.
- SockSeqpacket SockType = 5
-)
-
// A RightsControlMessage is a control message containing FDs.
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
@@ -175,7 +161,7 @@ type Endpoint interface {
// Type return the socket type, typically either SockStream, SockDgram
// or SockSeqpacket.
- Type() SockType
+ Type() linux.SockType
// GetLocalAddress returns the address to which the endpoint is bound.
GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
@@ -191,6 +177,10 @@ type Endpoint interface {
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs.
+ State() uint32
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -237,6 +227,10 @@ type BoundEndpoint interface {
// endpoint.
UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error)
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
// Release releases any resources held by the BoundEndpoint. It must be
// called before dropping all references to a BoundEndpoint returned by a
// function.
@@ -621,7 +615,7 @@ type connectedEndpoint struct {
GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// Type implements Endpoint.Type.
- Type() SockType
+ Type() linux.SockType
}
writeQueue *queue
@@ -645,7 +639,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages,
}
truncate := false
- if e.endpoint.Type() == SockStream {
+ if e.endpoint.Type() == linux.SOCK_STREAM {
// Since stream sockets don't preserve message boundaries, we
// can write only as much of the message as fits in the queue.
truncate = true
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 1414be0c6..b07e8d67b 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -17,6 +17,7 @@
package unix
import (
+ "fmt"
"strings"
"syscall"
@@ -55,22 +56,22 @@ type SocketOperations struct {
refs.AtomicRefCount
socket.SendReceiveTimeout
- ep transport.Endpoint
- isPacket bool
+ ep transport.Endpoint
+ stype linux.SockType
}
// New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
return fs.NewFile(ctx, d, flags, &SocketOperations{
- ep: ep,
- isPacket: isPacket,
+ ep: ep,
+ stype: stype,
})
}
@@ -88,6 +89,18 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
+func (s *SocketOperations) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
// Endpoint extracts the transport.Endpoint.
func (s *SocketOperations) Endpoint() transport.Endpoint {
return s.ep
@@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns := New(t, ep, s.isPacket)
+ ns := New(t, ep, s.stype)
defer ns.DecRef()
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, nil, 0, syserr.FromError(e)
}
- t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, nil
}
@@ -385,6 +398,10 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
defer ep.Release()
w.To = ep
+
+ if ep.Passcred() && w.Control.Credentials == nil {
+ w.Control.Credentials = control.MakeCreds(t)
+ }
}
n, err := src.CopyInTo(t, &w)
@@ -483,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
// Calculate the number of FDs for which we have space and if we are
// requesting credentials.
@@ -516,7 +534,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
var from interface{}
var fromLen uint32
- if r.From != nil {
+ if r.From != nil && len([]byte(r.From.Addr)) != 0 {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
@@ -524,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msgFlags |= linux.MSG_CTRUNC
}
- if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- if s.isPacket && n < int64(r.MsgSize) {
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
@@ -566,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- if s.isPacket && n < int64(r.MsgSize) {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
@@ -592,11 +610,22 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
+// State implements socket.Socket.State.
+func (s *SocketOperations) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
// provider is a unix domain socket provider.
type provider struct{}
// Socket returns a new unix domain socket.
-func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Check arguments.
if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
return nil, syserr.ErrProtocolNotSupported
@@ -604,43 +633,36 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
// Create the endpoint and socket.
var ep transport.Endpoint
- var isPacket bool
switch stype {
case linux.SOCK_DGRAM:
- isPacket = true
ep = transport.NewConnectionless()
- case linux.SOCK_SEQPACKET:
- isPacket = true
- fallthrough
- case linux.SOCK_STREAM:
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
ep = transport.NewConnectioned(stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
- return New(t, ep, isPacket), nil
+ return New(t, ep, stype), nil
}
// Pair creates a new pair of AF_UNIX connected sockets.
-func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
// Check arguments.
if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
return nil, nil, syserr.ErrProtocolNotSupported
}
- var isPacket bool
switch stype {
- case linux.SOCK_STREAM:
- case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
- isPacket = true
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ // Ok
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(stype, t.Kernel())
- s1 := New(t, ep1, isPacket)
- s2 := New(t, ep2, isPacket)
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
return s1, s2, nil
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index dbe53b9a2..0b5ef84c4 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -76,13 +76,13 @@ var SocketFamily = abi.ValueSet{
// SocketType are the possible socket(2) types.
var SocketType = abi.ValueSet{
- linux.SOCK_STREAM: "SOCK_STREAM",
- linux.SOCK_DGRAM: "SOCK_DGRAM",
- linux.SOCK_RAW: "SOCK_RAW",
- linux.SOCK_RDM: "SOCK_RDM",
- linux.SOCK_SEQPACKET: "SOCK_SEQPACKET",
- linux.SOCK_DCCP: "SOCK_DCCP",
- linux.SOCK_PACKET: "SOCK_PACKET",
+ uint64(linux.SOCK_STREAM): "SOCK_STREAM",
+ uint64(linux.SOCK_DGRAM): "SOCK_DGRAM",
+ uint64(linux.SOCK_RAW): "SOCK_RAW",
+ uint64(linux.SOCK_RDM): "SOCK_RDM",
+ uint64(linux.SOCK_SEQPACKET): "SOCK_SEQPACKET",
+ uint64(linux.SOCK_DCCP): "SOCK_DCCP",
+ uint64(linux.SOCK_PACKET): "SOCK_PACKET",
}
// SocketFlagSet are the possible socket(2) flags.
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index f76989ae2..1c057526b 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -19,6 +19,7 @@ go_library(
"sys_identity.go",
"sys_inotify.go",
"sys_lseek.go",
+ "sys_mempolicy.go",
"sys_mmap.go",
"sys_mount.go",
"sys_pipe.go",
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 1ba3695fb..72146ea63 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -92,6 +92,10 @@ func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op strin
// TODO(gvisor.dev/issue/161): In some cases SIGPIPE should
// also be sent to the application.
return nil
+ case syserror.ECONNRESET:
+ // For TCP sendfile connections, we may have a reset. But we
+ // should just return n as the result.
+ return nil
case syserror.ErrWouldBlock:
// Syscall would block, but completed a partial read/write.
// This case should only be returned by IssueIO for nonblocking
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 3e4d312af..5251c2463 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -34,33 +34,6 @@ const _AUDIT_ARCH_X86_64 = 0xc000003e
// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall
// numbers from Linux 4.4. The entries commented out are those syscalls we
// don't currently support.
-//
-// Syscall support is documented as annotations in Go comments of the form:
-// @Syscall(<name>, <key:value>, ...)
-//
-// Supported args and values are:
-//
-// - arg: A syscall option. This entry only applies to the syscall when given
-// this option.
-// - support: Indicates support level
-// - UNIMPLEMENTED: Unimplemented (default, implies returns:ENOSYS)
-// - PARTIAL: Partial support. Details should be provided in note.
-// - FULL: Full support
-// - returns: Indicates a known return value. Values are syscall errors. This
-// is treated as a string so you can use something like
-// "returns:EPERM or ENOSYS".
-// - issue: A Github issue number.
-// - note: A note
-//
-// Example:
-// // @Syscall(mmap, arg:MAP_PRIVATE, support:FULL, note:Private memory fully supported)
-// // @Syscall(mmap, arg:MAP_SHARED, issue:123, note:Shared memory not supported)
-// // @Syscall(setxattr, returns:ENOTSUP, note:Requires file system support)
-//
-// Annotations should be placed as close to their implementation as possible
-// (preferrably as part of a supporting function's Godoc) and should be
-// updated as syscall support changes. Unimplemented syscalls are documented
-// here due to their lack of a supporting function or method.
var AMD64 = &kernel.SyscallTable{
OS: abi.Linux,
Arch: arch.AMD64,
@@ -74,405 +47,338 @@ var AMD64 = &kernel.SyscallTable{
Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016",
},
AuditNumber: _AUDIT_ARCH_X86_64,
- Table: map[uintptr]kernel.SyscallFn{
- 0: Read,
- 1: Write,
- 2: Open,
- 3: Close,
- 4: Stat,
- 5: Fstat,
- 6: Lstat,
- 7: Poll,
- 8: Lseek,
- 9: Mmap,
- 10: Mprotect,
- 11: Munmap,
- 12: Brk,
- 13: RtSigaction,
- 14: RtSigprocmask,
- 15: RtSigreturn,
- 16: Ioctl,
- 17: Pread64,
- 18: Pwrite64,
- 19: Readv,
- 20: Writev,
- 21: Access,
- 22: Pipe,
- 23: Select,
- 24: SchedYield,
- 25: Mremap,
- 26: Msync,
- 27: Mincore,
- 28: Madvise,
- 29: Shmget,
- 30: Shmat,
- 31: Shmctl,
- 32: Dup,
- 33: Dup2,
- 34: Pause,
- 35: Nanosleep,
- 36: Getitimer,
- 37: Alarm,
- 38: Setitimer,
- 39: Getpid,
- 40: Sendfile,
- 41: Socket,
- 42: Connect,
- 43: Accept,
- 44: SendTo,
- 45: RecvFrom,
- 46: SendMsg,
- 47: RecvMsg,
- 48: Shutdown,
- 49: Bind,
- 50: Listen,
- 51: GetSockName,
- 52: GetPeerName,
- 53: SocketPair,
- 54: SetSockOpt,
- 55: GetSockOpt,
- 56: Clone,
- 57: Fork,
- 58: Vfork,
- 59: Execve,
- 60: Exit,
- 61: Wait4,
- 62: Kill,
- 63: Uname,
- 64: Semget,
- 65: Semop,
- 66: Semctl,
- 67: Shmdt,
- // 68: @Syscall(Msgget), TODO(b/29354921)
- // 69: @Syscall(Msgsnd), TODO(b/29354921)
- // 70: @Syscall(Msgrcv), TODO(b/29354921)
- // 71: @Syscall(Msgctl), TODO(b/29354921)
- 72: Fcntl,
- 73: Flock,
- 74: Fsync,
- 75: Fdatasync,
- 76: Truncate,
- 77: Ftruncate,
- 78: Getdents,
- 79: Getcwd,
- 80: Chdir,
- 81: Fchdir,
- 82: Rename,
- 83: Mkdir,
- 84: Rmdir,
- 85: Creat,
- 86: Link,
- 87: Unlink,
- 88: Symlink,
- 89: Readlink,
- 90: Chmod,
- 91: Fchmod,
- 92: Chown,
- 93: Fchown,
- 94: Lchown,
- 95: Umask,
- 96: Gettimeofday,
- 97: Getrlimit,
- 98: Getrusage,
- 99: Sysinfo,
- 100: Times,
- 101: Ptrace,
- 102: Getuid,
- 103: Syslog,
- 104: Getgid,
- 105: Setuid,
- 106: Setgid,
- 107: Geteuid,
- 108: Getegid,
- 109: Setpgid,
- 110: Getppid,
- 111: Getpgrp,
- 112: Setsid,
- 113: Setreuid,
- 114: Setregid,
- 115: Getgroups,
- 116: Setgroups,
- 117: Setresuid,
- 118: Getresuid,
- 119: Setresgid,
- 120: Getresgid,
- 121: Getpgid,
- // 122: @Syscall(Setfsuid), TODO(b/112851702)
- // 123: @Syscall(Setfsgid), TODO(b/112851702)
- 124: Getsid,
- 125: Capget,
- 126: Capset,
- 127: RtSigpending,
- 128: RtSigtimedwait,
- 129: RtSigqueueinfo,
- 130: RtSigsuspend,
- 131: Sigaltstack,
- 132: Utime,
- 133: Mknod,
- // @Syscall(Uselib, note:Obsolete)
- 134: syscalls.Error(syscall.ENOSYS),
- // @Syscall(SetPersonality, returns:EINVAL, note:Unable to change personality)
- 135: syscalls.ErrorWithEvent(syscall.EINVAL),
- // @Syscall(Ustat, note:Needs filesystem support)
- 136: syscalls.ErrorWithEvent(syscall.ENOSYS),
- 137: Statfs,
- 138: Fstatfs,
- // 139: @Syscall(Sysfs), TODO(gvisor.dev/issue/165)
- 140: Getpriority,
- 141: Setpriority,
- // @Syscall(SchedSetparam, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise)
- 142: syscalls.CapError(linux.CAP_SYS_NICE), // requires cap_sys_nice
- 143: SchedGetparam,
- 144: SchedSetscheduler,
- 145: SchedGetscheduler,
- 146: SchedGetPriorityMax,
- 147: SchedGetPriorityMin,
- // @Syscall(SchedRrGetInterval, returns:EPERM)
- 148: syscalls.ErrorWithEvent(syscall.EPERM),
- 149: Mlock,
- 150: Munlock,
- 151: Mlockall,
- 152: Munlockall,
- // @Syscall(Vhangup, returns:EPERM)
- 153: syscalls.CapError(linux.CAP_SYS_TTY_CONFIG),
- // @Syscall(ModifyLdt, returns:EPERM)
- 154: syscalls.Error(syscall.EPERM),
- // @Syscall(PivotRoot, returns:EPERM)
- 155: syscalls.Error(syscall.EPERM),
- // @Syscall(Sysctl, returns:EPERM)
- 156: syscalls.Error(syscall.EPERM), // syscall is "worthless"
- 157: Prctl,
- 158: ArchPrctl,
- // @Syscall(Adjtimex, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_time; ENOSYS otherwise)
- 159: syscalls.CapError(linux.CAP_SYS_TIME), // requires cap_sys_time
- 160: Setrlimit,
- 161: Chroot,
- 162: Sync,
- // @Syscall(Acct, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_pacct; ENOSYS otherwise)
- 163: syscalls.CapError(linux.CAP_SYS_PACCT), // requires cap_sys_pacct
- // @Syscall(Settimeofday, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_time; ENOSYS otherwise)
- 164: syscalls.CapError(linux.CAP_SYS_TIME), // requires cap_sys_time
- 165: Mount,
- 166: Umount2,
- // @Syscall(Swapon, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
- 167: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin
- // @Syscall(Swapoff, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
- 168: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin
- // @Syscall(Reboot, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_boot; ENOSYS otherwise)
- 169: syscalls.CapError(linux.CAP_SYS_BOOT), // requires cap_sys_boot
- 170: Sethostname,
- 171: Setdomainname,
- // @Syscall(Iopl, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_rawio; ENOSYS otherwise)
- 172: syscalls.CapError(linux.CAP_SYS_RAWIO), // requires cap_sys_rawio
- // @Syscall(Ioperm, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_rawio; ENOSYS otherwise)
- 173: syscalls.CapError(linux.CAP_SYS_RAWIO), // requires cap_sys_rawio
- // @Syscall(CreateModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
- 174: syscalls.CapError(linux.CAP_SYS_MODULE), // CreateModule, requires cap_sys_module
- // @Syscall(InitModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
- 175: syscalls.CapError(linux.CAP_SYS_MODULE), // requires cap_sys_module
- // @Syscall(DeleteModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
- 176: syscalls.CapError(linux.CAP_SYS_MODULE), // requires cap_sys_module
- // @Syscall(GetKernelSyms, note:Not supported in > 2.6)
- 177: syscalls.Error(syscall.ENOSYS),
- // @Syscall(QueryModule, note:Not supported in > 2.6)
- 178: syscalls.Error(syscall.ENOSYS),
- // @Syscall(Quotactl, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
- 179: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin (most operations)
- // @Syscall(Nfsservctl, note:Does not exist > 3.1)
- 180: syscalls.Error(syscall.ENOSYS),
- // @Syscall(Getpmsg, note:Not implemented in Linux)
- 181: syscalls.Error(syscall.ENOSYS),
- // @Syscall(Putpmsg, note:Not implemented in Linux)
- 182: syscalls.Error(syscall.ENOSYS),
- // @Syscall(AfsSyscall, note:Not implemented in Linux)
- 183: syscalls.Error(syscall.ENOSYS),
- // @Syscall(Tuxcall, note:Not implemented in Linux)
- 184: syscalls.Error(syscall.ENOSYS),
- // @Syscall(Security, note:Not implemented in Linux)
- 185: syscalls.Error(syscall.ENOSYS),
- 186: Gettid,
- 187: nil, // @Syscall(Readahead), TODO(b/29351341)
- // @Syscall(Setxattr, returns:ENOTSUP, note:Requires filesystem support)
- 188: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Lsetxattr, returns:ENOTSUP, note:Requires filesystem support)
- 189: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Fsetxattr, returns:ENOTSUP, note:Requires filesystem support)
- 190: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Getxattr, returns:ENOTSUP, note:Requires filesystem support)
- 191: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Lgetxattr, returns:ENOTSUP, note:Requires filesystem support)
- 192: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Fgetxattr, returns:ENOTSUP, note:Requires filesystem support)
- 193: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Listxattr, returns:ENOTSUP, note:Requires filesystem support)
- 194: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Llistxattr, returns:ENOTSUP, note:Requires filesystem support)
- 195: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Flistxattr, returns:ENOTSUP, note:Requires filesystem support)
- 196: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Removexattr, returns:ENOTSUP, note:Requires filesystem support)
- 197: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Lremovexattr, returns:ENOTSUP, note:Requires filesystem support)
- 198: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- // @Syscall(Fremovexattr, returns:ENOTSUP, note:Requires filesystem support)
- 199: syscalls.ErrorWithEvent(syscall.ENOTSUP),
- 200: Tkill,
- 201: Time,
- 202: Futex,
- 203: SchedSetaffinity,
- 204: SchedGetaffinity,
- // @Syscall(SetThreadArea, note:Expected to return ENOSYS on 64-bit)
- 205: syscalls.Error(syscall.ENOSYS),
- 206: IoSetup,
- 207: IoDestroy,
- 208: IoGetevents,
- 209: IoSubmit,
- 210: IoCancel,
- // @Syscall(GetThreadArea, note:Expected to return ENOSYS on 64-bit)
- 211: syscalls.Error(syscall.ENOSYS),
- // @Syscall(LookupDcookie, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
- 212: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin
- 213: EpollCreate,
- // @Syscall(EpollCtlOld, note:Deprecated)
- 214: syscalls.ErrorWithEvent(syscall.ENOSYS), // deprecated (afaik, unused)
- // @Syscall(EpollWaitOld, note:Deprecated)
- 215: syscalls.ErrorWithEvent(syscall.ENOSYS), // deprecated (afaik, unused)
- // @Syscall(RemapFilePages, note:Deprecated)
- 216: syscalls.ErrorWithEvent(syscall.ENOSYS), // deprecated since 3.16
- 217: Getdents64,
- 218: SetTidAddress,
- 219: RestartSyscall,
- // 220: @Syscall(Semtimedop), TODO(b/29354920)
- 221: Fadvise64,
- 222: TimerCreate,
- 223: TimerSettime,
- 224: TimerGettime,
- 225: TimerGetoverrun,
- 226: TimerDelete,
- 227: ClockSettime,
- 228: ClockGettime,
- 229: ClockGetres,
- 230: ClockNanosleep,
- 231: ExitGroup,
- 232: EpollWait,
- 233: EpollCtl,
- 234: Tgkill,
- 235: Utimes,
- // @Syscall(Vserver, note:Not implemented by Linux)
- 236: syscalls.Error(syscall.ENOSYS), // Vserver, not implemented by Linux
- // @Syscall(Mbind, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise), TODO(b/117792295)
- 237: syscalls.CapError(linux.CAP_SYS_NICE), // may require cap_sys_nice
- 238: SetMempolicy,
- 239: GetMempolicy,
- // 240: @Syscall(MqOpen), TODO(b/29354921)
- // 241: @Syscall(MqUnlink), TODO(b/29354921)
- // 242: @Syscall(MqTimedsend), TODO(b/29354921)
- // 243: @Syscall(MqTimedreceive), TODO(b/29354921)
- // 244: @Syscall(MqNotify), TODO(b/29354921)
- // 245: @Syscall(MqGetsetattr), TODO(b/29354921)
- 246: syscalls.CapError(linux.CAP_SYS_BOOT), // kexec_load, requires cap_sys_boot
- 247: Waitid,
- // @Syscall(AddKey, returns:EACCES, note:Not available to user)
- 248: syscalls.Error(syscall.EACCES),
- // @Syscall(RequestKey, returns:EACCES, note:Not available to user)
- 249: syscalls.Error(syscall.EACCES),
- // @Syscall(Keyctl, returns:EACCES, note:Not available to user)
- 250: syscalls.Error(syscall.EACCES),
- // @Syscall(IoprioSet, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
- 251: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_nice or cap_sys_admin (depending)
- // @Syscall(IoprioGet, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_admin; ENOSYS otherwise)
- 252: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_nice or cap_sys_admin (depending)
- 253: InotifyInit,
- 254: InotifyAddWatch,
- 255: InotifyRmWatch,
- // @Syscall(MigratePages, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise)
- 256: syscalls.CapError(linux.CAP_SYS_NICE),
- 257: Openat,
- 258: Mkdirat,
- 259: Mknodat,
- 260: Fchownat,
- 261: Futimesat,
- 262: Fstatat,
- 263: Unlinkat,
- 264: Renameat,
- 265: Linkat,
- 266: Symlinkat,
- 267: Readlinkat,
- 268: Fchmodat,
- 269: Faccessat,
- 270: Pselect,
- 271: Ppoll,
- 272: Unshare,
- // @Syscall(SetRobustList, note:Obsolete)
- 273: syscalls.Error(syscall.ENOSYS),
- // @Syscall(GetRobustList, note:Obsolete)
- 274: syscalls.Error(syscall.ENOSYS),
- 275: Splice,
- // 276: @Syscall(Tee), TODO(b/29354098)
- 277: SyncFileRange,
- // 278: @Syscall(Vmsplice), TODO(b/29354098)
- // @Syscall(MovePages, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise)
- 279: syscalls.CapError(linux.CAP_SYS_NICE), // requires cap_sys_nice (mostly)
- 280: Utimensat,
- 281: EpollPwait,
- // 282: @Syscall(Signalfd), TODO(b/19846426)
- 283: TimerfdCreate,
- 284: Eventfd,
- 285: Fallocate,
- 286: TimerfdSettime,
- 287: TimerfdGettime,
- 288: Accept4,
- // 289: @Syscall(Signalfd4), TODO(b/19846426)
- 290: Eventfd2,
- 291: EpollCreate1,
- 292: Dup3,
- 293: Pipe2,
- 294: InotifyInit1,
- 295: Preadv,
- 296: Pwritev,
- 297: RtTgsigqueueinfo,
- // @Syscall(PerfEventOpen, returns:ENODEV, note:No support for perf counters)
- 298: syscalls.ErrorWithEvent(syscall.ENODEV),
- 299: RecvMMsg,
- // @Syscall(FanotifyInit, note:Needs CONFIG_FANOTIFY)
- 300: syscalls.ErrorWithEvent(syscall.ENOSYS),
- // @Syscall(FanotifyMark, note:Needs CONFIG_FANOTIFY)
- 301: syscalls.ErrorWithEvent(syscall.ENOSYS),
- 302: Prlimit64,
- // @Syscall(NameToHandleAt, returns:EOPNOTSUPP, note:Needs filesystem support)
- 303: syscalls.ErrorWithEvent(syscall.EOPNOTSUPP),
- // @Syscall(OpenByHandleAt, returns:EOPNOTSUPP, note:Needs filesystem support)
- 304: syscalls.ErrorWithEvent(syscall.EOPNOTSUPP),
- // @Syscall(ClockAdjtime, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
- 305: syscalls.CapError(linux.CAP_SYS_TIME), // requires cap_sys_time
- 306: Syncfs,
- 307: SendMMsg,
- // 308: @Syscall(Setns), TODO(b/29354995)
- 309: Getcpu,
- // 310: @Syscall(ProcessVmReadv), TODO(gvisor.dev/issue/158) may require cap_sys_ptrace
- // 311: @Syscall(ProcessVmWritev), TODO(gvisor.dev/issue/158) may require cap_sys_ptrace
- // @Syscall(Kcmp, returns:EPERM or ENOSYS, note:Requires cap_sys_ptrace)
- 312: syscalls.CapError(linux.CAP_SYS_PTRACE),
- // @Syscall(FinitModule, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_module; ENOSYS otherwise)
- 313: syscalls.CapError(linux.CAP_SYS_MODULE),
- // 314: @Syscall(SchedSetattr), TODO(b/118902272), we have no scheduler
- // 315: @Syscall(SchedGetattr), TODO(b/118902272), we have no scheduler
- // 316: @Syscall(Renameat2), TODO(b/118902772)
- 317: Seccomp,
- 318: GetRandom,
- 319: MemfdCreate,
- // @Syscall(KexecFileLoad, EPERM or ENOSYS, note:Infeasible to support. Returns EPERM if the process does not have cap_sys_boot; ENOSYS otherwise)
- 320: syscalls.CapError(linux.CAP_SYS_BOOT),
- // @Syscall(Bpf, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_boot; ENOSYS otherwise)
- 321: syscalls.CapError(linux.CAP_SYS_ADMIN), // requires cap_sys_admin for all commands
- // 322: @Syscall(Execveat), TODO(b/118901836)
- // 323: @Syscall(Userfaultfd), TODO(b/118906345)
- // 324: @Syscall(Membarrier), TODO(b/118904897)
- 325: Mlock2,
+ Table: map[uintptr]kernel.Syscall{
+ 0: syscalls.Supported("read", Read),
+ 1: syscalls.Supported("write", Write),
+ 2: syscalls.Supported("open", Open),
+ 3: syscalls.Supported("close", Close),
+ 4: syscalls.Undocumented("stat", Stat),
+ 5: syscalls.Undocumented("fstat", Fstat),
+ 6: syscalls.Undocumented("lstat", Lstat),
+ 7: syscalls.Undocumented("poll", Poll),
+ 8: syscalls.Undocumented("lseek", Lseek),
+ 9: syscalls.Undocumented("mmap", Mmap),
+ 10: syscalls.Undocumented("mprotect", Mprotect),
+ 11: syscalls.Undocumented("munmap", Munmap),
+ 12: syscalls.Undocumented("brk", Brk),
+ 13: syscalls.Undocumented("rt_sigaction", RtSigaction),
+ 14: syscalls.Undocumented("rt_sigprocmask", RtSigprocmask),
+ 15: syscalls.Undocumented("rt_sigreturn", RtSigreturn),
+ 16: syscalls.Undocumented("ioctl", Ioctl),
+ 17: syscalls.Undocumented("pread64", Pread64),
+ 18: syscalls.Undocumented("pwrite64", Pwrite64),
+ 19: syscalls.Undocumented("readv", Readv),
+ 20: syscalls.Undocumented("writev", Writev),
+ 21: syscalls.Undocumented("access", Access),
+ 22: syscalls.Undocumented("pipe", Pipe),
+ 23: syscalls.Undocumented("select", Select),
+ 24: syscalls.Undocumented("sched_yield", SchedYield),
+ 25: syscalls.Undocumented("mremap", Mremap),
+ 26: syscalls.Undocumented("msync", Msync),
+ 27: syscalls.Undocumented("mincore", Mincore),
+ 28: syscalls.Undocumented("madvise", Madvise),
+ 29: syscalls.Undocumented("shmget", Shmget),
+ 30: syscalls.Undocumented("shmat", Shmat),
+ 31: syscalls.Undocumented("shmctl", Shmctl),
+ 32: syscalls.Undocumented("dup", Dup),
+ 33: syscalls.Undocumented("dup2", Dup2),
+ 34: syscalls.Undocumented("pause", Pause),
+ 35: syscalls.Undocumented("nanosleep", Nanosleep),
+ 36: syscalls.Undocumented("getitimer", Getitimer),
+ 37: syscalls.Undocumented("alarm", Alarm),
+ 38: syscalls.Undocumented("setitimer", Setitimer),
+ 39: syscalls.Undocumented("getpid", Getpid),
+ 40: syscalls.Undocumented("sendfile", Sendfile),
+ 41: syscalls.Undocumented("socket", Socket),
+ 42: syscalls.Undocumented("connect", Connect),
+ 43: syscalls.Undocumented("accept", Accept),
+ 44: syscalls.Undocumented("sendto", SendTo),
+ 45: syscalls.Undocumented("recvfrom", RecvFrom),
+ 46: syscalls.Undocumented("sendmsg", SendMsg),
+ 47: syscalls.Undocumented("recvmsg", RecvMsg),
+ 48: syscalls.Undocumented("shutdown", Shutdown),
+ 49: syscalls.Undocumented("bind", Bind),
+ 50: syscalls.Undocumented("listen", Listen),
+ 51: syscalls.Undocumented("getsockname", GetSockName),
+ 52: syscalls.Undocumented("getpeername", GetPeerName),
+ 53: syscalls.Undocumented("socketpair", SocketPair),
+ 54: syscalls.Undocumented("setsockopt", SetSockOpt),
+ 55: syscalls.Undocumented("getsockopt", GetSockOpt),
+ 56: syscalls.Undocumented("clone", Clone),
+ 57: syscalls.Undocumented("fork", Fork),
+ 58: syscalls.Undocumented("vfork", Vfork),
+ 59: syscalls.Undocumented("execve", Execve),
+ 60: syscalls.Undocumented("exit", Exit),
+ 61: syscalls.Undocumented("wait4", Wait4),
+ 62: syscalls.Undocumented("kill", Kill),
+ 63: syscalls.Undocumented("uname", Uname),
+ 64: syscalls.Undocumented("semget", Semget),
+ 65: syscalls.Undocumented("semop", Semop),
+ 66: syscalls.Undocumented("semctl", Semctl),
+ 67: syscalls.Undocumented("shmdt", Shmdt),
+ 68: syscalls.ErrorWithEvent("msgget", syscall.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 69: syscalls.ErrorWithEvent("msgsnd", syscall.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 70: syscalls.ErrorWithEvent("msgrcv", syscall.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 71: syscalls.ErrorWithEvent("msgctl", syscall.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 72: syscalls.Undocumented("fcntl", Fcntl),
+ 73: syscalls.Undocumented("flock", Flock),
+ 74: syscalls.Undocumented("fsync", Fsync),
+ 75: syscalls.Undocumented("fdatasync", Fdatasync),
+ 76: syscalls.Undocumented("truncate", Truncate),
+ 77: syscalls.Undocumented("ftruncate", Ftruncate),
+ 78: syscalls.Undocumented("getdents", Getdents),
+ 79: syscalls.Undocumented("getcwd", Getcwd),
+ 80: syscalls.Undocumented("chdir", Chdir),
+ 81: syscalls.Undocumented("fchdir", Fchdir),
+ 82: syscalls.Undocumented("rename", Rename),
+ 83: syscalls.Undocumented("mkdir", Mkdir),
+ 84: syscalls.Undocumented("rmdir", Rmdir),
+ 85: syscalls.Undocumented("creat", Creat),
+ 86: syscalls.Undocumented("link", Link),
+ 87: syscalls.Undocumented("link", Unlink),
+ 88: syscalls.Undocumented("symlink", Symlink),
+ 89: syscalls.Undocumented("readlink", Readlink),
+ 90: syscalls.Undocumented("chmod", Chmod),
+ 91: syscalls.Undocumented("fchmod", Fchmod),
+ 92: syscalls.Undocumented("chown", Chown),
+ 93: syscalls.Undocumented("fchown", Fchown),
+ 94: syscalls.Undocumented("lchown", Lchown),
+ 95: syscalls.Undocumented("umask", Umask),
+ 96: syscalls.Undocumented("gettimeofday", Gettimeofday),
+ 97: syscalls.Undocumented("getrlimit", Getrlimit),
+ 98: syscalls.Undocumented("getrusage", Getrusage),
+ 99: syscalls.Undocumented("sysinfo", Sysinfo),
+ 100: syscalls.Undocumented("times", Times),
+ 101: syscalls.Undocumented("ptrace", Ptrace),
+ 102: syscalls.Undocumented("getuid", Getuid),
+ 103: syscalls.Undocumented("syslog", Syslog),
+ 104: syscalls.Undocumented("getgid", Getgid),
+ 105: syscalls.Undocumented("setuid", Setuid),
+ 106: syscalls.Undocumented("setgid", Setgid),
+ 107: syscalls.Undocumented("geteuid", Geteuid),
+ 108: syscalls.Undocumented("getegid", Getegid),
+ 109: syscalls.Undocumented("setpgid", Setpgid),
+ 110: syscalls.Undocumented("getppid", Getppid),
+ 111: syscalls.Undocumented("getpgrp", Getpgrp),
+ 112: syscalls.Undocumented("setsid", Setsid),
+ 113: syscalls.Undocumented("setreuid", Setreuid),
+ 114: syscalls.Undocumented("setregid", Setregid),
+ 115: syscalls.Undocumented("getgroups", Getgroups),
+ 116: syscalls.Undocumented("setgroups", Setgroups),
+ 117: syscalls.Undocumented("setresuid", Setresuid),
+ 118: syscalls.Undocumented("getresuid", Getresuid),
+ 119: syscalls.Undocumented("setresgid", Setresgid),
+ 120: syscalls.Undocumented("setresgid", Getresgid),
+ 121: syscalls.Undocumented("getpgid", Getpgid),
+ 122: syscalls.ErrorWithEvent("setfsuid", syscall.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 123: syscalls.ErrorWithEvent("setfsgid", syscall.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 124: syscalls.Undocumented("getsid", Getsid),
+ 125: syscalls.Undocumented("capget", Capget),
+ 126: syscalls.Undocumented("capset", Capset),
+ 127: syscalls.Undocumented("rt_sigpending", RtSigpending),
+ 128: syscalls.Undocumented("rt_sigtimedwait", RtSigtimedwait),
+ 129: syscalls.Undocumented("rt_sigqueueinfo", RtSigqueueinfo),
+ 130: syscalls.Undocumented("rt_sigsuspend", RtSigsuspend),
+ 131: syscalls.Undocumented("sigaltstack", Sigaltstack),
+ 132: syscalls.Undocumented("utime", Utime),
+ 133: syscalls.Undocumented("mknod", Mknod),
+ 134: syscalls.Error("uselib", syscall.ENOSYS, "Obsolete", nil),
+ 135: syscalls.ErrorWithEvent("personality", syscall.EINVAL, "Unable to change personality.", nil),
+ 136: syscalls.ErrorWithEvent("ustat", syscall.ENOSYS, "Needs filesystem support.", nil),
+ 137: syscalls.Undocumented("statfs", Statfs),
+ 138: syscalls.Undocumented("fstatfs", Fstatfs),
+ 139: syscalls.ErrorWithEvent("sysfs", syscall.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
+ 140: syscalls.Undocumented("getpriority", Getpriority),
+ 141: syscalls.Undocumented("setpriority", Setpriority),
+ 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 143: syscalls.Undocumented("sched_getparam", SchedGetparam),
+ 144: syscalls.Undocumented("sched_setscheduler", SchedSetscheduler),
+ 145: syscalls.Undocumented("sched_getscheduler", SchedGetscheduler),
+ 146: syscalls.Undocumented("sched_get_priority_max", SchedGetPriorityMax),
+ 147: syscalls.Undocumented("sched_get_priority_min", SchedGetPriorityMin),
+ 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syscall.EPERM, "", nil),
+ 149: syscalls.Undocumented("mlock", Mlock),
+ 150: syscalls.Undocumented("munlock", Munlock),
+ 151: syscalls.Undocumented("mlockall", Mlockall),
+ 152: syscalls.Undocumented("munlockall", Munlockall),
+ 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 154: syscalls.Error("modify_ldt", syscall.EPERM, "", nil),
+ 155: syscalls.Error("pivot_root", syscall.EPERM, "", nil),
+ 156: syscalls.Error("sysctl", syscall.EPERM, `syscall is "worthless"`, nil),
+ 157: syscalls.Undocumented("prctl", Prctl),
+ 158: syscalls.Undocumented("arch_prctl", ArchPrctl),
+ 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 160: syscalls.Undocumented("setrlimit", Setrlimit),
+ 161: syscalls.Undocumented("chroot", Chroot),
+ 162: syscalls.Undocumented("sync", Sync),
+ 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 165: syscalls.Undocumented("mount", Mount),
+ 166: syscalls.Undocumented("umount2", Umount2),
+ 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 170: syscalls.Undocumented("sethostname", Sethostname),
+ 171: syscalls.Undocumented("setdomainname", Setdomainname),
+ 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil),
+ 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil),
+ 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
+ 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 177: syscalls.Error("get_kernel_syms", syscall.ENOSYS, "Not supported in > 2.6", nil),
+ 178: syscalls.Error("query_module", syscall.ENOSYS, "Not supported in > 2.6", nil),
+ 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 180: syscalls.Error("nfsservctl", syscall.ENOSYS, "Does not exist > 3.1", nil),
+ 181: syscalls.Error("getpmsg", syscall.ENOSYS, "Not implemented in Linux", nil),
+ 182: syscalls.Error("putpmsg", syscall.ENOSYS, "Not implemented in Linux", nil),
+ 183: syscalls.Error("afs_syscall", syscall.ENOSYS, "Not implemented in Linux", nil),
+ 184: syscalls.Error("tuxcall", syscall.ENOSYS, "Not implemented in Linux", nil),
+ 185: syscalls.Error("security", syscall.ENOSYS, "Not implemented in Linux", nil),
+ 186: syscalls.Undocumented("gettid", Gettid),
+ 187: syscalls.ErrorWithEvent("readahead", syscall.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341)
+ 188: syscalls.ErrorWithEvent("setxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 189: syscalls.ErrorWithEvent("lsetxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 190: syscalls.ErrorWithEvent("fsetxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 191: syscalls.ErrorWithEvent("getxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 192: syscalls.ErrorWithEvent("lgetxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 193: syscalls.ErrorWithEvent("fgetxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 194: syscalls.ErrorWithEvent("listxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 195: syscalls.ErrorWithEvent("llistxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 196: syscalls.ErrorWithEvent("flistxattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 197: syscalls.ErrorWithEvent("removexattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 198: syscalls.ErrorWithEvent("lremovexattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 199: syscalls.ErrorWithEvent("fremovexattr", syscall.ENOTSUP, "Requires filesystem support", nil),
+ 200: syscalls.Undocumented("tkill", Tkill),
+ 201: syscalls.Undocumented("time", Time),
+ 202: syscalls.Undocumented("futex", Futex),
+ 203: syscalls.Undocumented("sched_setaffinity", SchedSetaffinity),
+ 204: syscalls.Undocumented("sched_getaffinity", SchedGetaffinity),
+ 205: syscalls.Error("set_thread_area", syscall.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 206: syscalls.Undocumented("io_setup", IoSetup),
+ 207: syscalls.Undocumented("io_destroy", IoDestroy),
+ 208: syscalls.Undocumented("io_getevents", IoGetevents),
+ 209: syscalls.Undocumented("io_submit", IoSubmit),
+ 210: syscalls.Undocumented("io_cancel", IoCancel),
+ 211: syscalls.Error("get_thread_area", syscall.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 213: syscalls.Undocumented("epoll_create", EpollCreate),
+ 214: syscalls.ErrorWithEvent("epoll_ctl_old", syscall.ENOSYS, "Deprecated", nil),
+ 215: syscalls.ErrorWithEvent("epoll_wait_old", syscall.ENOSYS, "Deprecated", nil),
+ 216: syscalls.ErrorWithEvent("remap_file_pages", syscall.ENOSYS, "Deprecated since 3.16", nil),
+ 217: syscalls.Undocumented("getdents64", Getdents64),
+ 218: syscalls.Undocumented("set_tid_address", SetTidAddress),
+ 219: syscalls.Undocumented("restart_syscall", RestartSyscall),
+ 220: syscalls.ErrorWithEvent("semtimedop", syscall.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 221: syscalls.Undocumented("fadvise64", Fadvise64),
+ 222: syscalls.Undocumented("timer_create", TimerCreate),
+ 223: syscalls.Undocumented("timer_settime", TimerSettime),
+ 224: syscalls.Undocumented("timer_gettime", TimerGettime),
+ 225: syscalls.Undocumented("timer_getoverrun", TimerGetoverrun),
+ 226: syscalls.Undocumented("timer_delete", TimerDelete),
+ 227: syscalls.Undocumented("clock_settime", ClockSettime),
+ 228: syscalls.Undocumented("clock_gettime", ClockGettime),
+ 229: syscalls.Undocumented("clock_getres", ClockGetres),
+ 230: syscalls.Undocumented("clock_nanosleep", ClockNanosleep),
+ 231: syscalls.Undocumented("exit_group", ExitGroup),
+ 232: syscalls.Undocumented("epoll_wait", EpollWait),
+ 233: syscalls.Undocumented("epoll_ctl", EpollCtl),
+ 234: syscalls.Undocumented("tgkill", Tgkill),
+ 235: syscalls.Undocumented("utimes", Utimes),
+ 236: syscalls.Error("vserver", syscall.ENOSYS, "Not implemented by Linux", nil),
+ 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 238: syscalls.Undocumented("set_mempolicy", SetMempolicy),
+ 239: syscalls.Undocumented("get_mempolicy", GetMempolicy),
+ 240: syscalls.ErrorWithEvent("mq_open", syscall.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 241: syscalls.ErrorWithEvent("mq_unlink", syscall.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 242: syscalls.ErrorWithEvent("mq_timedsend", syscall.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 243: syscalls.ErrorWithEvent("mq_timedreceive", syscall.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 244: syscalls.ErrorWithEvent("mq_notify", syscall.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 245: syscalls.ErrorWithEvent("mq_getsetattr", syscall.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 247: syscalls.Undocumented("waitid", Waitid),
+ 248: syscalls.Error("add_key", syscall.EACCES, "Not available to user", nil),
+ 249: syscalls.Error("request_key", syscall.EACCES, "Not available to user", nil),
+ 250: syscalls.Error("keyctl", syscall.EACCES, "Not available to user", nil),
+ 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 253: syscalls.Undocumented("inotify_init", InotifyInit),
+ 254: syscalls.Undocumented("inotify_add_watch", InotifyAddWatch),
+ 255: syscalls.Undocumented("inotify_rm_watch", InotifyRmWatch),
+ 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 257: syscalls.Undocumented("openat", Openat),
+ 258: syscalls.Undocumented("mkdirat", Mkdirat),
+ 259: syscalls.Undocumented("mknodat", Mknodat),
+ 260: syscalls.Undocumented("fchownat", Fchownat),
+ 261: syscalls.Undocumented("futimesat", Futimesat),
+ 262: syscalls.Undocumented("fstatat", Fstatat),
+ 263: syscalls.Undocumented("unlinkat", Unlinkat),
+ 264: syscalls.Undocumented("renameat", Renameat),
+ 265: syscalls.Undocumented("linkat", Linkat),
+ 266: syscalls.Undocumented("symlinkat", Symlinkat),
+ 267: syscalls.Undocumented("readlinkat", Readlinkat),
+ 268: syscalls.Undocumented("fchmodat", Fchmodat),
+ 269: syscalls.Undocumented("faccessat", Faccessat),
+ 270: syscalls.Undocumented("pselect", Pselect),
+ 271: syscalls.Undocumented("ppoll", Ppoll),
+ 272: syscalls.Undocumented("unshare", Unshare),
+ 273: syscalls.Error("set_robust_list", syscall.ENOSYS, "Obsolete", nil),
+ 274: syscalls.Error("get_robust_list", syscall.ENOSYS, "Obsolete", nil),
+ 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 276: syscalls.ErrorWithEvent("tee", syscall.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 277: syscalls.Undocumented("sync_file_range", SyncFileRange),
+ 278: syscalls.ErrorWithEvent("vmsplice", syscall.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 280: syscalls.Undocumented("utimensat", Utimensat),
+ 281: syscalls.Undocumented("epoll_pwait", EpollPwait),
+ 282: syscalls.ErrorWithEvent("signalfd", syscall.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
+ 283: syscalls.Undocumented("timerfd_create", TimerfdCreate),
+ 284: syscalls.Undocumented("eventfd", Eventfd),
+ 285: syscalls.Undocumented("fallocate", Fallocate),
+ 286: syscalls.Undocumented("timerfd_settime", TimerfdSettime),
+ 287: syscalls.Undocumented("timerfd_gettime", TimerfdGettime),
+ 288: syscalls.Undocumented("accept4", Accept4),
+ 289: syscalls.ErrorWithEvent("signalfd4", syscall.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
+ 290: syscalls.Undocumented("eventfd2", Eventfd2),
+ 291: syscalls.Undocumented("epoll_create1", EpollCreate1),
+ 292: syscalls.Undocumented("dup3", Dup3),
+ 293: syscalls.Undocumented("pipe2", Pipe2),
+ 294: syscalls.Undocumented("inotify_init1", InotifyInit1),
+ 295: syscalls.Undocumented("preadv", Preadv),
+ 296: syscalls.Undocumented("pwritev", Pwritev),
+ 297: syscalls.Undocumented("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 298: syscalls.ErrorWithEvent("perf_event_open", syscall.ENODEV, "No support for perf counters", nil),
+ 299: syscalls.Undocumented("recvmmsg", RecvMMsg),
+ 300: syscalls.ErrorWithEvent("fanotify_init", syscall.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 301: syscalls.ErrorWithEvent("fanotify_mark", syscall.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 302: syscalls.Undocumented("prlimit64", Prlimit64),
+ 303: syscalls.ErrorWithEvent("name_to_handle_at", syscall.EOPNOTSUPP, "Needs filesystem support", nil),
+ 304: syscalls.ErrorWithEvent("open_by_handle_at", syscall.EOPNOTSUPP, "Needs filesystem support", nil),
+ 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 306: syscalls.Undocumented("syncfs", Syncfs),
+ 307: syscalls.Undocumented("sendmmsg", SendMMsg),
+ 308: syscalls.ErrorWithEvent("setns", syscall.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 309: syscalls.Undocumented("getcpu", Getcpu),
+ 310: syscalls.ErrorWithEvent("process_vm_readv", syscall.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 311: syscalls.ErrorWithEvent("process_vm_writev", syscall.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 314: syscalls.ErrorWithEvent("sched_setattr", syscall.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 315: syscalls.ErrorWithEvent("sched_getattr", syscall.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 316: syscalls.ErrorWithEvent("renameat2", syscall.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 317: syscalls.Undocumented("seccomp", Seccomp),
+ 318: syscalls.Undocumented("getrandom", GetRandom),
+ 319: syscalls.Undocumented("memfd_create", MemfdCreate),
+ 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
+ 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 322: syscalls.ErrorWithEvent("execveat", syscall.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 323: syscalls.ErrorWithEvent("userfaultfd", syscall.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 324: syscalls.ErrorWithEvent("membarrier", syscall.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
+ 325: syscalls.Undocumented("mlock2", Mlock2),
+
// Syscalls after 325 are "backports" from versions of Linux after 4.4.
- // 326: @Syscall(CopyFileRange),
- 327: Preadv2,
- 328: Pwritev2,
+ 326: syscalls.ErrorWithEvent("copy_file_range", syscall.ENOSYS, "", nil),
+ 327: syscalls.Undocumented("preadv2", Preadv2),
+ 328: syscalls.Undocumented("pwritev2", Pwritev2),
},
Emulate: map[usermem.Addr]uintptr{
diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go
new file mode 100644
index 000000000..652b2c206
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go
@@ -0,0 +1,312 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+// We unconditionally report a single NUMA node. This also means that our
+// "nodemask_t" is a single unsigned long (uint64).
+const (
+ maxNodes = 1
+ allowedNodemask = (1 << maxNodes) - 1
+)
+
+func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, error) {
+ // "nodemask points to a bit mask of node IDs that contains up to maxnode
+ // bits. The bit mask size is rounded to the next multiple of
+ // sizeof(unsigned long), but the kernel will use bits only up to maxnode.
+ // A NULL value of nodemask or a maxnode value of zero specifies the empty
+ // set of nodes. If the value of maxnode is zero, the nodemask argument is
+ // ignored." - set_mempolicy(2). Unfortunately, most of this is inaccurate
+ // because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses
+ // maxnode-1, not maxnode, as the number of bits.
+ bits := maxnode - 1
+ if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0
+ return 0, syserror.EINVAL
+ }
+ if bits == 0 {
+ return 0, nil
+ }
+ // Copy in the whole nodemask.
+ numUint64 := (bits + 63) / 64
+ buf := t.CopyScratchBuffer(int(numUint64) * 8)
+ if _, err := t.CopyInBytes(addr, buf); err != nil {
+ return 0, err
+ }
+ val := usermem.ByteOrder.Uint64(buf)
+ // Check that only allowed bits in the first unsigned long in the nodemask
+ // are set.
+ if val&^allowedNodemask != 0 {
+ return 0, syserror.EINVAL
+ }
+ // Check that all remaining bits in the nodemask are 0.
+ for i := 8; i < len(buf); i++ {
+ if buf[i] != 0 {
+ return 0, syserror.EINVAL
+ }
+ }
+ return val, nil
+}
+
+func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint64) error {
+ // mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of
+ // bits.
+ bits := maxnode - 1
+ if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0
+ return syserror.EINVAL
+ }
+ if bits == 0 {
+ return nil
+ }
+ // Copy out the first unsigned long in the nodemask.
+ buf := t.CopyScratchBuffer(8)
+ usermem.ByteOrder.PutUint64(buf, val)
+ if _, err := t.CopyOutBytes(addr, buf); err != nil {
+ return err
+ }
+ // Zero out remaining unsigned longs in the nodemask.
+ if bits > 64 {
+ remAddr, ok := addr.AddLength(8)
+ if !ok {
+ return syserror.EFAULT
+ }
+ remUint64 := (bits - 1) / 64
+ if _, err := t.MemoryManager().ZeroOut(t, remAddr, int64(remUint64)*8, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// GetMempolicy implements the syscall get_mempolicy(2).
+func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ mode := args[0].Pointer()
+ nodemask := args[1].Pointer()
+ maxnode := args[2].Uint()
+ addr := args[3].Pointer()
+ flags := args[4].Uint()
+
+ if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ nodeFlag := flags&linux.MPOL_F_NODE != 0
+ addrFlag := flags&linux.MPOL_F_ADDR != 0
+ memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0
+
+ // "EINVAL: The value specified by maxnode is less than the number of node
+ // IDs supported by the system." - get_mempolicy(2)
+ if nodemask != 0 && maxnode < maxNodes {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is
+ // ignored and the set of nodes (memories) that the thread is allowed to
+ // specify in subsequent calls to mbind(2) or set_mempolicy(2) (in the
+ // absence of any mode flags) is returned in nodemask."
+ if memsAllowed {
+ // "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either
+ // MPOL_F_ADDR or MPOL_F_NODE."
+ if nodeFlag || addrFlag {
+ return 0, nil, syserror.EINVAL
+ }
+ if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, nil
+ }
+
+ // "If flags specifies MPOL_F_ADDR, then information is returned about the
+ // policy governing the memory address given in addr. ... If the mode
+ // argument is not NULL, then get_mempolicy() will store the policy mode
+ // and any optional mode flags of the requested NUMA policy in the location
+ // pointed to by this argument. If nodemask is not NULL, then the nodemask
+ // associated with the policy will be stored in the location pointed to by
+ // this argument."
+ if addrFlag {
+ policy, nodemaskVal, err := t.MemoryManager().NumaPolicy(addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if nodeFlag {
+ // "If flags specifies both MPOL_F_NODE and MPOL_F_ADDR,
+ // get_mempolicy() will return the node ID of the node on which the
+ // address addr is allocated into the location pointed to by mode.
+ // If no page has yet been allocated for the specified address,
+ // get_mempolicy() will allocate a page as if the thread had
+ // performed a read (load) access to that address, and return the
+ // ID of the node where that page was allocated."
+ buf := t.CopyScratchBuffer(1)
+ _, err := t.CopyInBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ policy = 0 // maxNodes == 1
+ }
+ if mode != 0 {
+ if _, err := t.CopyOut(mode, policy); err != nil {
+ return 0, nil, err
+ }
+ }
+ if nodemask != 0 {
+ if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+ }
+
+ // "EINVAL: ... flags specified MPOL_F_ADDR and addr is NULL, or flags did
+ // not specify MPOL_F_ADDR and addr is not NULL." This is partially
+ // inaccurate: if flags specifies MPOL_F_ADDR,
+ // mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will
+ // just (usually) fail to find a VMA at address 0 and return EFAULT.
+ if addr != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If flags is specified as 0, then information about the calling thread's
+ // default policy (as set by set_mempolicy(2)) is returned, in the buffers
+ // pointed to by mode and nodemask. ... If flags specifies MPOL_F_NODE, but
+ // not MPOL_F_ADDR, and the thread's current policy is MPOL_INTERLEAVE,
+ // then get_mempolicy() will return in the location pointed to by a
+ // non-NULL mode argument, the node ID of the next node that will be used
+ // for interleaving of internal kernel pages allocated on behalf of the
+ // thread."
+ policy, nodemaskVal := t.NumaPolicy()
+ if nodeFlag {
+ if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
+ return 0, nil, syserror.EINVAL
+ }
+ policy = 0 // maxNodes == 1
+ }
+ if mode != 0 {
+ if _, err := t.CopyOut(mode, policy); err != nil {
+ return 0, nil, err
+ }
+ }
+ if nodemask != 0 {
+ if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// SetMempolicy implements the syscall set_mempolicy(2).
+func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ modeWithFlags := args[0].Int()
+ nodemask := args[1].Pointer()
+ maxnode := args[2].Uint()
+
+ modeWithFlags, nodemaskVal, err := copyInMempolicyNodemask(t, modeWithFlags, nodemask, maxnode)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ t.SetNumaPolicy(modeWithFlags, nodemaskVal)
+ return 0, nil, nil
+}
+
+// Mbind implements the syscall mbind(2).
+func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint64()
+ mode := args[2].Int()
+ nodemask := args[3].Pointer()
+ maxnode := args[4].Uint()
+ flags := args[5].Uint()
+
+ if flags&^linux.MPOL_MF_VALID != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be
+ // privileged (CAP_SYS_NICE) to use this flag." - mbind(2)
+ if flags&linux.MPOL_MF_MOVE_ALL != 0 && !t.HasCapability(linux.CAP_SYS_NICE) {
+ return 0, nil, syserror.EPERM
+ }
+
+ mode, nodemaskVal, err := copyInMempolicyNodemask(t, mode, nodemask, maxnode)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Since we claim to have only a single node, all flags can be ignored
+ // (since all pages must already be on that single node).
+ err = t.MemoryManager().SetNumaPolicy(addr, length, mode, nodemaskVal)
+ return 0, nil, err
+}
+
+func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags int32, nodemask usermem.Addr, maxnode uint32) (int32, uint64, error) {
+ flags := modeWithFlags & linux.MPOL_MODE_FLAGS
+ mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS
+ if flags == linux.MPOL_MODE_FLAGS {
+ // Can't specify both mode flags simultaneously.
+ return 0, 0, syserror.EINVAL
+ }
+ if mode < 0 || mode >= linux.MPOL_MAX {
+ // Must specify a valid mode.
+ return 0, 0, syserror.EINVAL
+ }
+
+ var nodemaskVal uint64
+ if nodemask != 0 {
+ var err error
+ nodemaskVal, err = copyInNodemask(t, nodemask, maxnode)
+ if err != nil {
+ return 0, 0, err
+ }
+ }
+
+ switch mode {
+ case linux.MPOL_DEFAULT:
+ // "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate;
+ // Linux allows a nodemask to be specified, as long as it is empty.
+ if nodemaskVal != 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ case linux.MPOL_BIND, linux.MPOL_INTERLEAVE:
+ // These require a non-empty nodemask.
+ if nodemaskVal == 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ case linux.MPOL_PREFERRED:
+ // This permits an empty nodemask, as long as no flags are set.
+ if nodemaskVal == 0 && flags != 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ case linux.MPOL_LOCAL:
+ // This requires an empty nodemask and no flags set ...
+ if nodemaskVal != 0 || flags != 0 {
+ return 0, 0, syserror.EINVAL
+ }
+ // ... and is implemented as MPOL_PREFERRED.
+ mode = linux.MPOL_PREFERRED
+ default:
+ // Unknown mode, which we should have rejected above.
+ panic(fmt.Sprintf("unknown mode: %v", mode))
+ }
+
+ return mode | flags, nodemaskVal, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 64a6e639c..9926f0ac5 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -204,151 +204,6 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
}
-func copyOutIfNotNull(t *kernel.Task, ptr usermem.Addr, val interface{}) (int, error) {
- if ptr != 0 {
- return t.CopyOut(ptr, val)
- }
- return 0, nil
-}
-
-// GetMempolicy implements the syscall get_mempolicy(2).
-func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- mode := args[0].Pointer()
- nodemask := args[1].Pointer()
- maxnode := args[2].Uint()
- addr := args[3].Pointer()
- flags := args[4].Uint()
-
- memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0
- nodeFlag := flags&linux.MPOL_F_NODE != 0
- addrFlag := flags&linux.MPOL_F_ADDR != 0
-
- // TODO(rahat): Once sysfs is implemented, report a single numa node in
- // /sys/devices/system/node.
- if nodemask != 0 && maxnode < 1 {
- return 0, nil, syserror.EINVAL
- }
-
- // 'addr' provided iff 'addrFlag' set.
- if addrFlag == (addr == 0) {
- return 0, nil, syserror.EINVAL
- }
-
- // Default policy for the thread.
- if flags == 0 {
- policy, nodemaskVal := t.NumaPolicy()
- if _, err := copyOutIfNotNull(t, mode, policy); err != nil {
- return 0, nil, syserror.EFAULT
- }
- if _, err := copyOutIfNotNull(t, nodemask, nodemaskVal); err != nil {
- return 0, nil, syserror.EFAULT
- }
- return 0, nil, nil
- }
-
- // Report all nodes available to caller.
- if memsAllowed {
- // MPOL_F_NODE and MPOL_F_ADDR not allowed with MPOL_F_MEMS_ALLOWED.
- if nodeFlag || addrFlag {
- return 0, nil, syserror.EINVAL
- }
-
- // Report a single numa node.
- if _, err := copyOutIfNotNull(t, nodemask, uint32(0x1)); err != nil {
- return 0, nil, syserror.EFAULT
- }
- return 0, nil, nil
- }
-
- if addrFlag {
- if nodeFlag {
- // Return the id for the node where 'addr' resides, via 'mode'.
- //
- // The real get_mempolicy(2) allocates the page referenced by 'addr'
- // by simulating a read, if it is unallocated before the call. It
- // then returns the node the page is allocated on through the mode
- // pointer.
- b := t.CopyScratchBuffer(1)
- _, err := t.CopyInBytes(addr, b)
- if err != nil {
- return 0, nil, syserror.EFAULT
- }
- if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil {
- return 0, nil, syserror.EFAULT
- }
- } else {
- storedPolicy, _ := t.NumaPolicy()
- // Return the policy governing the memory referenced by 'addr'.
- if _, err := copyOutIfNotNull(t, mode, int32(storedPolicy)); err != nil {
- return 0, nil, syserror.EFAULT
- }
- }
- return 0, nil, nil
- }
-
- storedPolicy, _ := t.NumaPolicy()
- if nodeFlag && (storedPolicy&^linux.MPOL_MODE_FLAGS == linux.MPOL_INTERLEAVE) {
- // Policy for current thread is to interleave memory between
- // nodes. Return the next node we'll allocate on. Since we only have a
- // single node, this is always node 0.
- if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil {
- return 0, nil, syserror.EFAULT
- }
- return 0, nil, nil
- }
-
- return 0, nil, syserror.EINVAL
-}
-
-func allowedNodesMask() uint32 {
- const maxNodes = 1
- return ^uint32((1 << maxNodes) - 1)
-}
-
-// SetMempolicy implements the syscall set_mempolicy(2).
-func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- modeWithFlags := args[0].Int()
- nodemask := args[1].Pointer()
- maxnode := args[2].Uint()
-
- if nodemask != 0 && maxnode < 1 {
- return 0, nil, syserror.EINVAL
- }
-
- if modeWithFlags&linux.MPOL_MODE_FLAGS == linux.MPOL_MODE_FLAGS {
- // Can't specify multiple modes simultaneously.
- return 0, nil, syserror.EINVAL
- }
-
- mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS
- if mode < 0 || mode >= linux.MPOL_MAX {
- // Must specify a valid mode.
- return 0, nil, syserror.EINVAL
- }
-
- var nodemaskVal uint32
- // Nodemask may be empty for some policy modes.
- if nodemask != 0 && maxnode > 0 {
- if _, err := t.CopyIn(nodemask, &nodemaskVal); err != nil {
- return 0, nil, syserror.EFAULT
- }
- }
-
- if (mode == linux.MPOL_INTERLEAVE || mode == linux.MPOL_BIND) && nodemaskVal == 0 {
- // Mode requires a non-empty nodemask, but got an empty nodemask.
- return 0, nil, syserror.EINVAL
- }
-
- if nodemaskVal&allowedNodesMask() != 0 {
- // Invalid node specified.
- return 0, nil, syserror.EINVAL
- }
-
- t.SetNumaPolicy(int32(modeWithFlags), nodemaskVal)
-
- return 0, nil, nil
-}
-
// Mincore implements the syscall mincore(2).
func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 117ae1a0e..1b7e5616b 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -15,6 +15,7 @@
package linux
import (
+ "fmt"
"syscall"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
@@ -23,6 +24,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/mm"
)
// Prctl implements linux syscall prctl(2).
@@ -44,6 +46,33 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
_, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal()))
return 0, nil, err
+ case linux.PR_GET_DUMPABLE:
+ d := t.MemoryManager().Dumpability()
+ switch d {
+ case mm.NotDumpable:
+ return linux.SUID_DUMP_DISABLE, nil, nil
+ case mm.UserDumpable:
+ return linux.SUID_DUMP_USER, nil, nil
+ case mm.RootDumpable:
+ return linux.SUID_DUMP_ROOT, nil, nil
+ default:
+ panic(fmt.Sprintf("Unknown dumpability %v", d))
+ }
+
+ case linux.PR_SET_DUMPABLE:
+ var d mm.Dumpability
+ switch args[1].Int() {
+ case linux.SUID_DUMP_DISABLE:
+ d = mm.NotDumpable
+ case linux.SUID_DUMP_USER:
+ d = mm.UserDumpable
+ default:
+ // N.B. Userspace may not pass SUID_DUMP_ROOT.
+ return 0, nil, syscall.EINVAL
+ }
+ t.MemoryManager().SetDumpability(d)
+ return 0, nil, nil
+
case linux.PR_GET_KEEPCAPS:
if t.Credentials().KeepCaps {
return 1, nil, nil
@@ -171,9 +200,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return 0, nil, t.DropBoundingCapability(cp)
- case linux.PR_GET_DUMPABLE,
- linux.PR_SET_DUMPABLE,
- linux.PR_GET_TIMING,
+ case linux.PR_GET_TIMING,
linux.PR_SET_TIMING,
linux.PR_GET_TSC,
linux.PR_SET_TSC,
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 8f4dbf3bc..31295a6a9 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -188,7 +188,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// Create the new socket.
- s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol)
+ s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol)
if e != nil {
return 0, nil, e.ToError()
}
@@ -227,7 +227,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Create the socket pair.
- s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol)
+ s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol)
if e != nil {
return 0, nil, e.ToError()
}
diff --git a/pkg/sentry/syscalls/syscalls.go b/pkg/sentry/syscalls/syscalls.go
index 5d10b3824..48c114232 100644
--- a/pkg/sentry/syscalls/syscalls.go
+++ b/pkg/sentry/syscalls/syscalls.go
@@ -25,37 +25,97 @@
package syscalls
import (
+ "fmt"
+ "syscall"
+
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/syserror"
)
+// Supported returns a syscall that is fully supported.
+func Supported(name string, fn kernel.SyscallFn) kernel.Syscall {
+ return kernel.Syscall{
+ Name: name,
+ Fn: fn,
+ SupportLevel: kernel.SupportFull,
+ Note: "Full Support",
+ }
+}
+
+// Undocumented returns a syscall that is undocumented.
+func Undocumented(name string, fn kernel.SyscallFn) kernel.Syscall {
+ return kernel.Syscall{
+ Name: name,
+ Fn: fn,
+ SupportLevel: kernel.SupportUndocumented,
+ }
+}
+
+// PartiallySupported returns a syscall that has a partial implementation.
+func PartiallySupported(name string, fn kernel.SyscallFn, note string, urls []string) kernel.Syscall {
+ return kernel.Syscall{
+ Name: name,
+ Fn: fn,
+ SupportLevel: kernel.SupportPartial,
+ Note: note,
+ URLs: urls,
+ }
+}
+
// Error returns a syscall handler that will always give the passed error.
-func Error(err error) kernel.SyscallFn {
- return func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- return 0, nil, err
+func Error(name string, err syscall.Errno, note string, urls []string) kernel.Syscall {
+ if note != "" {
+ note = note + "; "
+ }
+ return kernel.Syscall{
+ Name: name,
+ Fn: func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, err
+ },
+ SupportLevel: kernel.SupportUnimplemented,
+ Note: fmt.Sprintf("%sReturns %q", note, err.Error()),
+ URLs: urls,
}
}
// ErrorWithEvent gives a syscall function that sends an unimplemented
// syscall event via the event channel and returns the passed error.
-func ErrorWithEvent(err error) kernel.SyscallFn {
- return func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, err
+func ErrorWithEvent(name string, err syscall.Errno, note string, urls []string) kernel.Syscall {
+ if note != "" {
+ note = note + "; "
+ }
+ return kernel.Syscall{
+ Name: name,
+ Fn: func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, err
+ },
+ SupportLevel: kernel.SupportUnimplemented,
+ Note: fmt.Sprintf("%sReturns %q", note, err.Error()),
+ URLs: urls,
}
}
// CapError gives a syscall function that checks for capability c. If the task
// has the capability, it returns ENOSYS, otherwise EPERM. To unprivileged
// tasks, it will seem like there is an implementation.
-func CapError(c linux.Capability) kernel.SyscallFn {
- return func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- if !t.HasCapability(c) {
- return 0, nil, syserror.EPERM
- }
- t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.ENOSYS
+func CapError(name string, c linux.Capability, note string, urls []string) kernel.Syscall {
+ if note != "" {
+ note = note + "; "
+ }
+ return kernel.Syscall{
+ Name: name,
+ Fn: func(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if !t.HasCapability(c) {
+ return 0, nil, syserror.EPERM
+ }
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ },
+ SupportLevel: kernel.SupportUnimplemented,
+ Note: fmt.Sprintf("%sReturns %q if the process does not have %s; %q otherwise", note, syserror.EPERM, c.String(), syserror.ENOSYS),
+ URLs: urls,
}
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index b2f8f6832..b50579a92 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -32,7 +32,7 @@ go_library(
"tsc_arm64.s",
],
importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/time",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/log",
"//pkg/metric",
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
index 09198496b..860733061 100644
--- a/pkg/sentry/usage/BUILD
+++ b/pkg/sentry/usage/BUILD
@@ -17,6 +17,6 @@ go_library(
],
deps = [
"//pkg/bits",
- "//pkg/sentry/memutil",
+ "//pkg/memutil",
],
)
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index c316f1597..9ed974ccb 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -22,7 +22,7 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/bits"
- "gvisor.googlesource.com/gvisor/pkg/sentry/memutil"
+ "gvisor.googlesource.com/gvisor/pkg/memutil"
)
// MemoryKind represents a type of memory used by the application.
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 31e4d6ada..9dde327a2 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -222,9 +222,11 @@ func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts
return int(r.Addr - addr), nil
}
-// copyStringIncrement is the maximum number of bytes that are copied from
-// virtual memory at a time by CopyStringIn.
-const copyStringIncrement = 64
+// CopyStringIn tuning parameters, defined outside that function for tests.
+const (
+ copyStringIncrement = 64
+ copyStringMaxInitBufLen = 256
+)
// CopyStringIn copies a NUL-terminated string of unknown length from the
// memory mapped at addr in uio and returns it as a string (not including the
@@ -234,31 +236,38 @@ const copyStringIncrement = 64
//
// Preconditions: As for IO.CopyFromUser. maxlen >= 0.
func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) {
- buf := make([]byte, maxlen)
+ initLen := maxlen
+ if initLen > copyStringMaxInitBufLen {
+ initLen = copyStringMaxInitBufLen
+ }
+ buf := make([]byte, initLen)
var done int
for done < maxlen {
- start, ok := addr.AddLength(uint64(done))
- if !ok {
- // Last page of kernel memory. The application can't use this
- // anyway.
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
- }
// Read up to copyStringIncrement bytes at a time.
readlen := copyStringIncrement
if readlen > maxlen-done {
readlen = maxlen - done
}
- end, ok := start.AddLength(uint64(readlen))
+ end, ok := addr.AddLength(uint64(readlen))
if !ok {
return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
// copies up to the end of application-mappable memory succeed.
- if start.RoundDown() != end.RoundDown() {
+ if addr.RoundDown() != end.RoundDown() {
end = end.RoundDown()
+ readlen = int(end - addr)
+ }
+ // Ensure that our buffer is large enough to accommodate the read.
+ if done+readlen > len(buf) {
+ newBufLen := len(buf) * 2
+ if newBufLen > maxlen {
+ newBufLen = maxlen
+ }
+ buf = append(buf, make([]byte, newBufLen-len(buf))...)
}
- n, err := uio.CopyIn(ctx, start, buf[done:done+int(end-start)], opts)
+ n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
for i, c := range buf[done : done+n] {
@@ -270,6 +279,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
}
+ addr = end
}
return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/sentry/usermem/usermem_test.go
index 4a07118b7..575e5039d 100644
--- a/pkg/sentry/usermem/usermem_test.go
+++ b/pkg/sentry/usermem/usermem_test.go
@@ -192,6 +192,7 @@ func TestCopyObject(t *testing.T) {
}
func TestCopyStringInShort(t *testing.T) {
+ // Tests for string length <= copyStringIncrement.
want := strings.Repeat("A", copyStringIncrement-2)
mem := want + "\x00"
if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
@@ -200,13 +201,25 @@ func TestCopyStringInShort(t *testing.T) {
}
func TestCopyStringInLong(t *testing.T) {
- want := strings.Repeat("A", copyStringIncrement+1)
+ // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen
+ // (requiring multiple calls to IO.CopyIn()).
+ want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4)
mem := want + "\x00"
if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
}
}
+func TestCopyStringInVeryLong(t *testing.T) {
+ // Tests for string length > copyStringMaxInitBufLen (requiring buffer
+ // reallocation).
+ want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4)
+ mem := want + "\x00"
+ if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
+ }
+}
+
func TestCopyStringInNoTerminatingZeroByte(t *testing.T) {
want := strings.Repeat("A", copyStringIncrement-1)
got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{})
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 1f889c2a0..b88e2e7bf 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -21,12 +21,29 @@
// FD based endpoints can be used in the networking stack by calling New() to
// create a new endpoint, and then passing it as an argument to
// Stack.CreateNIC().
+//
+// FD based endpoints can use more than one file descriptor to read incoming
+// packets. If there are more than one FDs specified and the underlying FD is an
+// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the
+// host kernel will consistently hash the packets to the sockets. This ensures
+// that packets for the same TCP streams are not reordered.
+//
+// Similarly if more than one FD's are specified where the underlying FD is not
+// AF_PACKET then it's the caller's responsibility to ensure that all inbound
+// packets on the descriptors are consistently 5 tuple hashed to one of the
+// descriptors to prevent TCP reordering.
+//
+// Since netstack today does not compute 5 tuple hashes for outgoing packets we
+// only use the first FD to write outbound packets. Once 5 tuple hashes for
+// all outbound packets are available we will make use of all underlying FD's to
+// write outbound packets.
package fdbased
import (
"fmt"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
@@ -65,8 +82,10 @@ const (
)
type endpoint struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
+ // fds is the set of file descriptors each identifying one inbound/outbound
+ // channel. The endpoint will dispatch from all inbound channels as well as
+ // hash outbound packets to specific channels based on the packet hash.
+ fds []int
// mtu (maximum transmission unit) is the maximum size of a packet.
mtu uint32
@@ -85,8 +104,8 @@ type endpoint struct {
// its end of the communication pipe.
closed func(*tcpip.Error)
- inboundDispatcher linkDispatcher
- dispatcher stack.NetworkDispatcher
+ inboundDispatchers []linkDispatcher
+ dispatcher stack.NetworkDispatcher
// packetDispatchMode controls the packet dispatcher used by this
// endpoint.
@@ -99,17 +118,47 @@ type endpoint struct {
// Options specify the details about the fd-based endpoint to be created.
type Options struct {
- FD int
- MTU uint32
- EthernetHeader bool
- ClosedFunc func(*tcpip.Error)
- Address tcpip.LinkAddress
- SaveRestore bool
- DisconnectOk bool
- GSOMaxSize uint32
+ // FDs is a set of FDs used to read/write packets.
+ FDs []int
+
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // EthernetHeader if true, indicates that the endpoint should read/write
+ // ethernet frames instead of IP packets.
+ EthernetHeader bool
+
+ // ClosedFunc is a function to be called when an endpoint's peer (if
+ // any) closes its end of the communication pipe.
+ ClosedFunc func(*tcpip.Error)
+
+ // Address is the link address for this endpoint. Only used if
+ // EthernetHeader is true.
+ Address tcpip.LinkAddress
+
+ // SaveRestore if true, indicates that this NIC capability set should
+ // include CapabilitySaveRestore
+ SaveRestore bool
+
+ // DisconnectOk if true, indicates that this NIC capability set should
+ // include CapabilityDisconnectOk.
+ DisconnectOk bool
+
+ // GSOMaxSize is the maximum GSO packet size. It is zero if GSO is
+ // disabled.
+ GSOMaxSize uint32
+
+ // PacketDispatchMode specifies the type of inbound dispatcher to be
+ // used for this endpoint.
PacketDispatchMode PacketDispatchMode
- TXChecksumOffload bool
- RXChecksumOffload bool
+
+ // TXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
}
// New creates a new fd-based endpoint.
@@ -117,10 +166,6 @@ type Options struct {
// Makes fd non-blocking, but does not take ownership of fd, which must remain
// open for the lifetime of the returned endpoint.
func New(opts *Options) (tcpip.LinkEndpointID, error) {
- if err := syscall.SetNonblock(opts.FD, true); err != nil {
- return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", opts.FD, err)
- }
-
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
@@ -144,8 +189,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
caps |= stack.CapabilityDisconnectOk
}
+ if len(opts.FDs) == 0 {
+ return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ }
+
e := &endpoint{
- fd: opts.FD,
+ fds: opts.FDs,
mtu: opts.MTU,
caps: caps,
closed: opts.ClosedFunc,
@@ -154,46 +203,71 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
packetDispatchMode: opts.PacketDispatchMode,
}
- isSocket, err := isSocketFD(e.fd)
- if err != nil {
- return 0, err
- }
- if isSocket {
- if opts.GSOMaxSize != 0 {
- e.caps |= stack.CapabilityGSO
- e.gsoMaxSize = opts.GSOMaxSize
+ // Create per channel dispatchers.
+ for i := 0; i < len(e.fds); i++ {
+ fd := e.fds[i]
+ if err := syscall.SetNonblock(fd, true); err != nil {
+ return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
}
- }
- e.inboundDispatcher, err = createInboundDispatcher(e, isSocket)
- if err != nil {
- return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+
+ isSocket, err := isSocketFD(fd)
+ if err != nil {
+ return 0, err
+ }
+ if isSocket {
+ if opts.GSOMaxSize != 0 {
+ e.caps |= stack.CapabilityGSO
+ e.gsoMaxSize = opts.GSOMaxSize
+ }
+ }
+ inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
+ if err != nil {
+ return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ }
+ e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
return stack.RegisterLinkEndpoint(e), nil
}
-func createInboundDispatcher(e *endpoint, isSocket bool) (linkDispatcher, error) {
+func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
// By default use the readv() dispatcher as it works with all kinds of
// FDs (tap/tun/unix domain sockets and af_packet).
- inboundDispatcher, err := newReadVDispatcher(e.fd, e)
+ inboundDispatcher, err := newReadVDispatcher(fd, e)
if err != nil {
- return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", e.fd, e, err)
+ return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", fd, e, err)
}
if isSocket {
+ sa, err := unix.Getsockname(fd)
+ if err != nil {
+ return nil, fmt.Errorf("unix.Getsockname(%d) = %v", fd, err)
+ }
+ switch sa.(type) {
+ case *unix.SockaddrLinklayer:
+ // enable PACKET_FANOUT mode is the underlying socket is
+ // of type AF_PACKET.
+ const fanoutID = 1
+ const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ fanoutArg := fanoutID | fanoutType<<16
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
+ return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
+ }
+ }
+
switch e.packetDispatchMode {
case PacketMMap:
- inboundDispatcher, err = newPacketMMapDispatcher(e.fd, e)
+ inboundDispatcher, err = newPacketMMapDispatcher(fd, e)
if err != nil {
- return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", e.fd, e, err)
+ return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", fd, e, err)
}
case RecvMMsg:
// If the provided FD is a socket then we optimize
// packet reads by using recvmmsg() instead of read() to
// read packets in a batch.
- inboundDispatcher, err = newRecvMMsgDispatcher(e.fd, e)
+ inboundDispatcher, err = newRecvMMsgDispatcher(fd, e)
if err != nil {
- return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", e.fd, e, err)
+ return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", fd, e, err)
}
}
}
@@ -215,7 +289,9 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// Link endpoints are not savable. When transportation endpoints are
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
- go e.dispatchLoop() // S/R-SAFE: See above.
+ for i := range e.inboundDispatchers {
+ go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ }
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
@@ -305,26 +381,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
}
}
- return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView())
+ return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView())
}
if payload.Size() == 0 {
- return rawfile.NonBlockingWrite(e.fd, hdr.View())
+ return rawfile.NonBlockingWrite(e.fds[0], hdr.View())
}
- return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil)
+ return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
}
// WriteRawPacket writes a raw packet directly to the file descriptor.
func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fd, packet)
+ return rawfile.NonBlockingWrite(e.fds[0], packet)
}
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
-func (e *endpoint) dispatchLoop() *tcpip.Error {
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error {
for {
- cont, err := e.inboundDispatcher.dispatch()
+ cont, err := inboundDispatcher.dispatch()
if err != nil || !cont {
if e.closed != nil {
e.closed(err)
@@ -363,7 +439,7 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti
syscall.SetNonblock(fd, true)
e := &InjectableEndpoint{endpoint: endpoint{
- fd: fd,
+ fds: []int{fd},
mtu: mtu,
caps: capabilities,
}}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index fd1722074..ba3e09192 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -67,7 +67,7 @@ func newContext(t *testing.T, opt *Options) *context {
done <- struct{}{}
}
- opt.FD = fds[1]
+ opt.FDs = []int{fds[1]}
epID, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index fccabd554..98581e50e 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -118,7 +118,7 @@ func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcp
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("recv", protocol, vv.First())
+ logPacket("recv", protocol, vv.First(), nil)
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
vs := vv.Views()
@@ -198,7 +198,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
// the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", protocol, hdr.View())
+ logPacket("send", protocol, hdr.View(), gso)
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
hdrBuf := hdr.View()
@@ -240,7 +240,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -404,5 +404,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
return
}
+ if gso != nil {
+ details += fmt.Sprintf(" gso: %+v", gso)
+ }
+
log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index da07a39e5..44b1d5b9b 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -215,7 +215,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
views[0] = hdr.View()
views = append(views, payload.Views()...)
vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
- e.HandlePacket(r, vv)
+ loopedR := r.MakeLoopedRoute()
+ e.HandlePacket(&loopedR, vv)
+ loopedR.Release()
}
if loop&stack.PacketOut == 0 {
return nil
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 4b8cd496b..bcae98e1f 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -108,7 +108,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
views[0] = hdr.View()
views = append(views, payload.Views()...)
vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
- e.HandlePacket(r, vv)
+ loopedR := r.MakeLoopedRoute()
+ e.HandlePacket(&loopedR, vv)
+ loopedR.Release()
}
if loop&stack.PacketOut == 0 {
return nil
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 1681de56e..1fa899e7e 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -137,7 +137,7 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{FD: fd, MTU: mtu})
+ linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
if err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 642607f83..d47085581 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -129,7 +129,7 @@ func main() {
}
linkID, err := fdbased.New(&fdbased.Options{
- FD: fd,
+ FDs: []int{fd},
MTU: mtu,
EthernetHeader: *tap,
Address: tcpip.LinkAddress(maddr),
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 3d4c282a9..55ed02479 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -187,3 +187,13 @@ func (r *Route) Clone() Route {
r.ref.incRef()
return *r
}
+
+// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast.
+func (r *Route) MakeLoopedRoute() Route {
+ l := r.Clone()
+ if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
+ l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
+ l.RemoteLinkAddress = l.LocalLinkAddress
+ }
+ return l
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8d74f1543..e8a9392b5 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -188,6 +188,10 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s
f.proto.controlCount++
}
+func (f *fakeTransportEndpoint) State() uint32 {
+ return 0
+}
+
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index f9886c6e4..04c776205 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -377,6 +377,10 @@ type Endpoint interface {
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
+
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
}
// WriteOptions contains options for Endpoint.Write.
@@ -468,6 +472,14 @@ type KeepaliveIntervalOption time.Duration
// closed.
type KeepaliveCountOption int
+// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
+// the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9aa6f3978..84a2b53b7 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
+ "//pkg/tcpip/transport/tcp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index e2b90ef10..b8005093a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -708,3 +708,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
+
+// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
+// expose internal socket state.
+func (e *endpoint) State() uint32 {
+ return 0
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 1daf5823f..e4ff50c91 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -519,3 +519,8 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b
ep.waiterQueue.Notify(waiter.EventIn)
}
}
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index e31b03f7d..a9dbfb930 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -21,6 +21,7 @@ go_library(
"accept.go",
"connect.go",
"cubic.go",
+ "cubic_state.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
@@ -70,6 +71,7 @@ go_test(
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
+ "tcp_noracedetector_test.go",
"tcp_sack_test.go",
"tcp_test.go",
"tcp_timestamp_test.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d4b860975..d05259c0a 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -19,7 +19,6 @@ import (
"encoding/binary"
"hash"
"io"
- "log"
"sync"
"time"
@@ -227,7 +226,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
}
n.isRegistered = true
- n.state = stateConnecting
// Create sender and receiver.
//
@@ -253,14 +251,15 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newHandshake(ep, l.rcvWnd)
- h.resetToSynRcvd(cookie, irs, opts, l.listenEP)
+ h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
return nil, err
}
-
- ep.state = stateConnected
+ ep.mu.Lock()
+ ep.state = StateEstablished
+ ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
@@ -277,7 +276,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.RLock()
state := e.state
e.mu.RUnlock()
- if state == stateListen {
+ if state == StateListen {
e.acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
@@ -295,7 +294,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer decSynRcvdCount()
defer e.decSynRcvdCount()
defer s.decRef()
-
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -307,8 +305,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
func (e *endpoint) incSynRcvdCount() bool {
e.mu.Lock()
- log.Printf("l: %d, c: %d, e.synRcvdCount: %d", len(e.acceptedChan), cap(e.acceptedChan), e.synRcvdCount)
- if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c {
+ if e.synRcvdCount >= cap(e.acceptedChan) {
e.mu.Unlock()
return false
}
@@ -323,6 +320,16 @@ func (e *endpoint) decSynRcvdCount() {
e.mu.Unlock()
}
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.mu.Lock()
+ if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
+ e.mu.Unlock()
+ return true
+ }
+ e.mu.Unlock()
+ return false
+}
+
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
@@ -330,20 +337,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
case header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
if incSynRcvdCount() {
- // Drop the SYN if the listen endpoint's accept queue is
- // overflowing.
- if e.incSynRcvdCount() {
- log.Printf("processing syn packet")
+ // Only handle the syn if the following conditions hold
+ // - accept queue is not full.
+ // - number of connections in synRcvd state is less than the
+ // backlog.
+ if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
return
}
- log.Printf("dropping syn packet")
+ decSynRcvdCount()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
} else {
- // TODO(bhaskerh): Increment syncookie sent stat.
+ // If cookies are in use but the endpoint accept queue
+ // is full then drop the syn.
+ if e.acceptQueueIsFull() {
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN with window scaling because we currently
// dont't encode this information in the cookie.
@@ -361,7 +375,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
case header.TCPFlagAck:
- if len(e.acceptedChan) == cap(e.acceptedChan) {
+ if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
// retransmitted by the sender anyway and we can
@@ -411,7 +425,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
- n.state = stateConnected
+ n.state = StateEstablished
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -434,7 +448,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
- e.state = stateClosed
+ e.state = StateClose
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 2aed6f286..dd671f7ce 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -60,12 +60,11 @@ const (
// handshake holds the state used during a TCP 3-way handshake.
type handshake struct {
- ep *endpoint
- listenEP *endpoint // only non nil when doing passive connects.
- state handshakeState
- active bool
- flags uint8
- ackNum seqnum.Value
+ ep *endpoint
+ state handshakeState
+ active bool
+ flags uint8
+ ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
@@ -142,7 +141,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -150,7 +149,9 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
- h.listenEP = listenEP
+ h.ep.mu.Lock()
+ h.ep.state = StateSynRecv
+ h.ep.mu.Unlock()
}
// checkAck checks if the ACK number, if present, of a segment received during
@@ -219,6 +220,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// but resend our own SYN and wait for it to be acknowledged in the
// SYN-RCVD state.
h.state = handshakeSynRcvd
+ h.ep.mu.Lock()
+ h.ep.state = StateSynRecv
+ h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: rcvSynOpts.TS,
@@ -281,18 +285,6 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
- // listenContext is also used by a tcp.Forwarder and in that
- // context we do not have a listening endpoint to check the
- // backlog. So skip this check if listenEP is nil.
- if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) {
- // If there is no space in the accept queue to accept
- // this endpoint then silently drop this ACK. The peer
- // will anyway resend the ack and we can complete the
- // connection the next time it's retransmitted.
- h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- h.ep.stack.Stats().DroppedPackets.Increment()
- return nil
- }
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -663,7 +655,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -714,8 +706,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
// protocol goroutine.
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
-
- e.state = stateError
+ e.state = StateError
e.hardError = err
}
@@ -871,14 +862,19 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// handshake, and then inform potential waiters about its
// completion.
h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+ e.mu.Lock()
+ h.ep.state = StateSynSent
+ e.mu.Unlock()
+
if err := h.execute(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.state = stateError
+ e.state = StateError
e.hardError = err
+
// Lock released below.
epilogue()
@@ -900,7 +896,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
- e.state = stateConnected
+ e.state = StateEstablished
drained := e.drainDone != nil
e.mu.Unlock()
if drained {
@@ -1000,7 +996,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return err
}
}
- if e.state != stateError {
+ if e.state != StateError {
close(e.drainDone)
<-e.undrain
}
@@ -1056,8 +1052,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Mark endpoint as closed.
e.mu.Lock()
- if e.state != stateError {
- e.state = stateClosed
+ if e.state != StateError {
+ e.state = StateClose
}
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
index e618cd2b9..7b1f5e763 100644
--- a/pkg/tcpip/transport/tcp/cubic.go
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -23,6 +23,7 @@ import (
// control algorithm state.
//
// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
type cubicState struct {
// wLastMax is the previous wMax value.
wLastMax float64
@@ -33,7 +34,7 @@ type cubicState struct {
// t denotes the time when the current congestion avoidance
// was entered.
- t time.Time
+ t time.Time `state:".(unixTime)"`
// numCongestionEvents tracks the number of congestion events since last
// RTO.
diff --git a/pkg/sentry/memutil/memutil.go b/pkg/tcpip/transport/tcp/cubic_state.go
index a4154c42a..d0f58cfaf 100644
--- a/pkg/sentry/memutil/memutil.go
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package memutil contains the utility functions for memory operations.
-package memutil
+package tcp
+
+import (
+ "time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+ return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+ c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b66610ee2..1efe9d3fb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -32,18 +33,81 @@ import (
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
-type endpointState int
+// EndpointState represents the state of a TCP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that are represented in a netstack-specific manner and
+// may not be meaningful externally. Specifically, they need to be translated to
+// Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateListen
- stateConnecting
- stateConnected
- stateClosed
- stateError
+ // Endpoint states internal to netstack. These map to the TCP state CLOSED.
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+ StateError
+
+ // TCP protocol states.
+ StateEstablished
+ StateSynSent
+ StateSynRecv
+ StateFinWait1
+ StateFinWait2
+ StateTimeWait
+ StateClose
+ StateCloseWait
+ StateLastAck
+ StateListen
+ StateClosing
)
+// connected is the set of states where an endpoint is connected to a peer.
+func (s EndpointState) connected() bool {
+ switch s {
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ return true
+ default:
+ return false
+ }
+}
+
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnecting:
+ return "CONNECTING"
+ case StateError:
+ return "ERROR"
+ case StateEstablished:
+ return "ESTABLISHED"
+ case StateSynSent:
+ return "SYN-SENT"
+ case StateSynRecv:
+ return "SYN-RCVD"
+ case StateFinWait1:
+ return "FIN-WAIT1"
+ case StateFinWait2:
+ return "FIN-WAIT2"
+ case StateTimeWait:
+ return "TIME-WAIT"
+ case StateClose:
+ return "CLOSED"
+ case StateCloseWait:
+ return "CLOSE-WAIT"
+ case StateLastAck:
+ return "LAST-ACK"
+ case StateListen:
+ return "LISTEN"
+ case StateClosing:
+ return "CLOSING"
+ default:
+ panic("unreachable")
+ }
+}
+
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
@@ -108,10 +172,14 @@ type endpoint struct {
rcvBufUsed int
// The following fields are protected by the mutex.
- mu sync.RWMutex `state:"nosave"`
- id stack.TransportEndpointID
- state endpointState `state:".(endpointState)"`
- isPortReserved bool `state:"manual"`
+ mu sync.RWMutex `state:"nosave"`
+ id stack.TransportEndpointID
+
+ // state endpointState `state:".(endpointState)"`
+ // pState ProtocolState
+ state EndpointState `state:".(EndpointState)"`
+
+ isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
route stack.Route `state:"manual"`
@@ -219,7 +287,7 @@ type endpoint struct {
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
- cc CongestionControlOption
+ cc tcpip.CongestionControlOption
// The following are used when a "packet too big" control packet is
// received. They are protected by sndBufMu. They are used to
@@ -304,6 +372,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
stack: stack,
netProto: netProto,
waiterQueue: waiterQueue,
+ state: StateInitial,
rcvBufSize: DefaultBufferSize,
sndBufSize: DefaultBufferSize,
sndMTU: int(math.MaxInt32),
@@ -326,7 +395,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.rcvBufSize = rs.Default
}
- var cs CongestionControlOption
+ var cs tcpip.CongestionControlOption
if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
@@ -335,7 +404,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite
e.probe = p
}
- e.segmentQueue.setLimit(2 * e.rcvBufSize)
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
e.workMu.Lock()
e.tsOffset = timeStampOffset()
@@ -351,14 +420,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
defer e.mu.RUnlock()
switch e.state {
- case stateInitial, stateBound, stateConnecting:
+ case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case stateClosed, stateError:
+ case StateClose, StateError:
// Ready for anything.
result = mask
- case stateListen:
+ case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
if len(e.acceptedChan) > 0 {
@@ -366,7 +435,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
}
- case stateConnected:
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -427,7 +496,7 @@ func (e *endpoint) Close() {
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == stateListen && e.isPortReserved {
+ if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.isRegistered = false
@@ -487,15 +556,15 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
e.mu.RLock()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
- // would cause the state to become stateError so we should allow the
+ // would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+ if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.hardError
e.mu.RUnlock()
- if s == stateError {
+ if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -511,7 +580,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || e.state != stateConnected {
+ if e.rcvClosed || !e.state.connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -547,9 +616,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
defer e.mu.RUnlock()
// The endpoint cannot be written to if it's not connected.
- if e.state != stateConnected {
+ if !e.state.connected() {
switch e.state {
- case stateError:
+ case StateError:
return 0, nil, e.hardError
default:
return 0, nil, tcpip.ErrClosedForSend
@@ -612,8 +681,8 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; s != stateConnected && s != stateClosed {
- if s == stateError {
+ if s := e.state; !s.connected() && s != StateClose {
+ if s == StateError {
return 0, tcpip.ControlMessages{}, e.hardError
}
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -623,7 +692,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || e.state != stateConnected {
+ if e.rcvClosed || !e.state.connected() {
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
@@ -757,8 +826,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
e.rcvListMu.Unlock()
- e.segmentQueue.setLimit(2 * size)
-
e.notifyProtocolGoroutine(mask)
return nil
@@ -791,7 +858,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -832,6 +899,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
+ var avail tcpip.AvailableCongestionControlOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+ return err
+ }
+ availCC := strings.Split(string(avail), " ")
+ for _, cc := range availCC {
+ if v == tcpip.CongestionControlOption(cc) {
+ // Acquire the work mutex as we may need to
+ // reinitialize the congestion control state.
+ e.mu.Lock()
+ state := e.state
+ e.cc = v
+ e.mu.Unlock()
+ switch state {
+ case StateEstablished:
+ e.workMu.Lock()
+ e.mu.Lock()
+ if e.state == state {
+ e.snd.cc = e.snd.initCongestionControl(e.cc)
+ }
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ }
+ return nil
+ }
+ }
+
+ // Linux returns ENOENT when an invalid congestion
+ // control algorithm is specified.
+ return tcpip.ErrNoSuchFile
default:
return nil
}
@@ -843,7 +944,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
defer e.mu.RUnlock()
// The endpoint cannot be in listen state.
- if e.state == stateListen {
+ if e.state == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1001,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.CongestionControlOption:
+ e.mu.Lock()
+ *o = e.cc
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1059,7 +1166,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
nicid := addr.NIC
switch e.state {
- case stateBound:
+ case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
if e.boundNICID == 0 {
@@ -1072,16 +1179,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
nicid = e.boundNICID
- case stateInitial:
- // Nothing to do. We'll eventually fill-in the gaps in the ID
- // (if any) when we find a route.
+ case StateInitial:
+ // Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
+ // when we find a route.
- case stateConnecting:
- // A connection request has already been issued but hasn't
- // completed yet.
+ case StateConnecting, StateSynSent, StateSynRecv:
+ // A connection request has already been issued but hasn't completed
+ // yet.
return tcpip.ErrAlreadyConnecting
- case stateConnected:
+ case StateEstablished:
// The endpoint is already connected. If caller hasn't been notified yet, return success.
if !e.isConnectNotified {
e.isConnectNotified = true
@@ -1090,7 +1197,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// Otherwise return that it's already connected.
return tcpip.ErrAlreadyConnected
- case stateError:
+ case StateError:
return e.hardError
default:
@@ -1156,7 +1263,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
}
e.isRegistered = true
- e.state = stateConnecting
+ e.state = StateConnecting
e.route = r.Clone()
e.boundNICID = nicid
e.effectiveNetProtos = netProtos
@@ -1177,7 +1284,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = stateConnected
+ e.state = StateEstablished
}
if run {
@@ -1201,8 +1308,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
defer e.mu.Unlock()
e.shutdownFlags |= flags
- switch e.state {
- case stateConnected:
+ switch {
+ case e.state.connected():
// Close for read.
if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
// Mark read side as closed.
@@ -1243,7 +1350,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.sndCloseWaker.Assert()
}
- case stateListen:
+ case e.state == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
@@ -1271,7 +1378,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
// When the endpoint shuts down, it sets workerCleanup to true, and from
// that point onward, acceptedChan is the responsibility of the cleanup()
// method (and should not be touched anywhere else, including here).
- if e.state == stateListen && !e.workerCleanup {
+ if e.state == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
if len(e.acceptedChan) > backlog {
@@ -1290,7 +1397,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Endpoint must be bound before it can transition to listen mode.
- if e.state != stateBound {
+ if e.state != StateBound {
return tcpip.ErrInvalidEndpointState
}
@@ -1300,7 +1407,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
e.isRegistered = true
- e.state = stateListen
+ e.state = StateListen
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
@@ -1327,7 +1434,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != stateListen {
+ if e.state != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
@@ -1355,7 +1462,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrAlreadyBound
}
@@ -1410,7 +1517,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
return nil
}
@@ -1432,7 +1539,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if !e.state.connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1741,3 +1848,11 @@ func (e *endpoint) initGSO() {
gso.MaxSize = e.route.GSOMaxSize()
e.gso = gso
}
+
+// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
+// state for diagnostics.
+func (e *endpoint) State() uint32 {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 27b0be046..5f30c2374 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -49,8 +49,8 @@ func (e *endpoint) beforeSave() {
defer e.mu.Unlock()
switch e.state {
- case stateInitial, stateBound:
- case stateConnected:
+ case StateInitial, StateBound:
+ case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
@@ -66,17 +66,17 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case stateListen, stateConnecting:
+ case StateListen, StateConnecting:
e.drainSegmentLocked()
- if e.state != stateClosed && e.state != stateError {
+ if e.state != StateClose && e.state != StateError {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
break
}
fallthrough
- case stateError, stateClosed:
- for e.state == stateError && e.workerRunning {
+ case StateError, StateClose:
+ for e.state == StateError && e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
@@ -92,7 +92,7 @@ func (e *endpoint) beforeSave() {
panic("endpoint still has waiters upon save")
}
- if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) {
+ if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
}
}
@@ -132,7 +132,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
}
// saveState is invoked by stateify.
-func (e *endpoint) saveState() endpointState {
+func (e *endpoint) saveState() EndpointState {
return e.state
}
@@ -146,15 +146,15 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state endpointState) {
+func (e *endpoint) loadState(state EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
switch state {
- case stateConnected:
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
connectedLoading.Add(1)
- case stateListen:
+ case StateListen:
listenLoading.Add(1)
- case stateConnecting:
+ case StateConnecting, StateSynSent, StateSynRecv:
connectingLoading.Add(1)
}
e.state = state
@@ -163,12 +163,12 @@ func (e *endpoint) loadState(state endpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
e.stack = stack.StackFromEnv
- e.segmentQueue.setLimit(2 * e.rcvBufSize)
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
state := e.state
switch state {
- case stateInitial, stateBound, stateListen, stateConnecting, stateConnected:
+ case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
@@ -181,7 +181,7 @@ func (e *endpoint) afterLoad() {
}
bind := func() {
- e.state = stateInitial
+ e.state = StateInitial
if len(e.bindAddress) == 0 {
e.bindAddress = e.id.LocalAddress
}
@@ -191,7 +191,7 @@ func (e *endpoint) afterLoad() {
}
switch state {
- case stateConnected:
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
bind()
if len(e.connectingAddress) == 0 {
// This endpoint is accepted by netstack but not yet by
@@ -211,7 +211,7 @@ func (e *endpoint) afterLoad() {
panic("endpoint connecting failed: " + err.String())
}
connectedLoading.Done()
- case stateListen:
+ case StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -223,7 +223,7 @@ func (e *endpoint) afterLoad() {
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case stateConnecting:
+ case StateConnecting, StateSynSent, StateSynRecv:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -235,7 +235,7 @@ func (e *endpoint) afterLoad() {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case stateBound:
+ case StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -244,7 +244,7 @@ func (e *endpoint) afterLoad() {
bind()
tcpip.AsyncLoading.Done()
}()
- case stateClosed:
+ case StateClose:
if e.isPortReserved {
tcpip.AsyncLoading.Add(1)
go func() {
@@ -252,12 +252,12 @@ func (e *endpoint) afterLoad() {
listenLoading.Wait()
connectingLoading.Wait()
bind()
- e.state = stateClosed
+ e.state = StateClose
tcpip.AsyncLoading.Done()
}()
}
fallthrough
- case stateError:
+ case StateError:
tcpip.DeleteDanglingEndpoint(e)
}
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index d31a1edcb..59f4009a1 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -48,6 +48,10 @@ const (
// MaxBufferSize is the largest size a receive and send buffer can grow to.
maxBufferSize = 4 << 20 // 4MB
+
+ // MaxUnprocessedSegments is the maximum number of unprocessed segments
+ // that can be queued for a given endpoint.
+ MaxUnprocessedSegments = 300
)
// SACKEnabled option can be used to enable SACK support in the TCP
@@ -75,13 +79,6 @@ const (
ccCubic = "cubic"
)
-// CongestionControlOption sets the current congestion control algorithm.
-type CongestionControlOption string
-
-// AvailableCongestionControlOption returns the supported congestion control
-// algorithms.
-type AvailableCongestionControlOption string
-
type protocol struct {
mu sync.Mutex
sackEnabled bool
@@ -89,7 +86,6 @@ type protocol struct {
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
- allowedCongestionControl []string
}
// Number returns the tcp protocol number.
@@ -184,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
- case CongestionControlOption:
+ case tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
if string(v) == c {
p.mu.Lock()
@@ -193,7 +189,9 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
}
- return tcpip.ErrInvalidOptionValue
+ // linux returns ENOENT when an invalid congestion control
+ // is specified.
+ return tcpip.ErrNoSuchFile
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -219,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
*v = p.recvBufferSize
p.mu.Unlock()
return nil
- case *CongestionControlOption:
+ case *tcpip.CongestionControlOption:
p.mu.Lock()
- *v = CongestionControlOption(p.congestionControl)
+ *v = tcpip.CongestionControlOption(p.congestionControl)
p.mu.Unlock()
return nil
- case *AvailableCongestionControlOption:
+ case *tcpip.AvailableCongestionControlOption:
p.mu.Lock()
- *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.Unlock()
return nil
default:
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index b08a0e356..f02fa6105 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -134,6 +134,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// sequence numbers that have been consumed.
TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+ // Handle FIN or FIN-ACK.
if s.flagIsSet(header.TCPFlagFin) {
r.rcvNxt++
@@ -144,6 +145,25 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
r.closed = true
r.ep.readyToRead(nil)
+ // We just received a FIN, our next state depends on whether we sent a
+ // FIN already or not.
+ r.ep.mu.Lock()
+ switch r.ep.state {
+ case StateEstablished:
+ r.ep.state = StateCloseWait
+ case StateFinWait1:
+ if s.flagIsSet(header.TCPFlagAck) {
+ // FIN-ACK, transition to TIME-WAIT.
+ r.ep.state = StateTimeWait
+ } else {
+ // Simultaneous close, expecting a final ACK.
+ r.ep.state = StateClosing
+ }
+ case StateFinWait2:
+ r.ep.state = StateTimeWait
+ }
+ r.ep.mu.Unlock()
+
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
// caller is using it.
@@ -156,6 +176,23 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
r.pendingRcvdSegments[i].decRef()
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+
+ return true
+ }
+
+ // Handle ACK (not FIN-ACK, which we handled above) during one of the
+ // shutdown states.
+ if s.flagIsSet(header.TCPFlagAck) {
+ r.ep.mu.Lock()
+ switch r.ep.state {
+ case StateFinWait1:
+ r.ep.state = StateFinWait2
+ case StateClosing:
+ r.ep.state = StateTimeWait
+ case StateLastAck:
+ r.ep.state = StateClose
+ }
+ r.ep.mu.Unlock()
}
return true
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 3b020e580..e0759225e 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -16,8 +16,6 @@ package tcp
import (
"sync"
-
- "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
)
// segmentQueue is a bounded, thread-safe queue of TCP segments.
@@ -58,7 +56,7 @@ func (q *segmentQueue) enqueue(s *segment) bool {
r := q.used < q.limit
if r {
q.list.PushBack(s)
- q.used += s.data.Size() + header.TCPMinimumSize
+ q.used++
}
q.mu.Unlock()
@@ -73,7 +71,7 @@ func (q *segmentQueue) dequeue() *segment {
s := q.list.Front()
if s != nil {
q.list.Remove(s)
- q.used -= s.data.Size() + header.TCPMinimumSize
+ q.used--
}
q.mu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index afc1d0a55..daa3e8341 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s := &sender{
ep: ep,
- sndCwnd: InitialCwnd,
- sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
@@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
return s
}
-func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+ s.sndCwnd = InitialCwnd
+ s.sndSsthresh = math.MaxInt64
+
switch congestionControlName {
case ccCubic:
return newCubicCC(s)
@@ -632,6 +636,10 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
+ // Transition to FIN-WAIT1 state since we're initiating an active close.
+ s.ep.mu.Lock()
+ s.ep.state = StateFinWait1
+ s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
@@ -779,7 +787,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding++
+ s.outstanding += s.pCount(seg)
s.writeNext = seg.Next()
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
new file mode 100644
index 000000000..4d1519860
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -0,0 +1,519 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// These tests are flaky when run under the go race detector due to some
+// iterations taking long enough that the retransmit timer can kick in causing
+// the congestion window measurements to fail due to extra packets etc.
+//
+// +build !race
+
+package tcp_test
+
+import (
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+func TestFastRecovery(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Send 3 duplicate acks. This should force an immediate retransmit of
+ // the pending packet and put the sender into fast recovery.
+ rtxOffset := bytesRead - maxPayload*expected
+ for i := 0; i < 3; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ // Receive the retransmitted packet.
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ }
+
+ // Now send 7 mode duplicate acks. Each of these should cause a window
+ // inflation by 1 and cause the sender to send an extra packet.
+ for i := 0; i < 7; i++ {
+ c.SendAck(790, rtxOffset)
+ }
+
+ recover := bytesRead
+
+ // Ensure no new packets arrive.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge half of the pending data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+ c.SendAck(790, rtxOffset)
+
+ // Receive the retransmit due to partial ack.
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+ t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+ t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ }
+
+ // Receive the 10 extra packets that should have been released due to
+ // the congestion window inflation in recovery.
+ for i := 0; i < 10; i++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // A partial ACK during recovery should reduce congestion window by the
+ // number acked. Since we had "expected" packets outstanding before sending
+ // partial ack and we acked expected/2 , the cwnd and outstanding should
+ // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered
+ // fast recovery). Which means the sender should not send any more packets
+ // till we ack this one.
+ c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
+ 50*time.Millisecond)
+
+ // Acknowledge all pending data to recover point.
+ c.SendAck(790, recover)
+
+ // At this point, the cwnd should reset to expected/2 and there are 10
+ // packets outstanding.
+ //
+ // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
+ // the same segment that takes us out of recovery. But because of that
+ // the actual cwnd at exit of recovery will be expected/2 + 1 as we
+ // acked a cwnd worth of packets which will increase the cwnd further by
+ // 1 in congestion avoidance.
+ //
+ // Now in the first iteration since there are 10 packets outstanding.
+ // We would expect to get expected/2 +1 - 10 packets. But subsequent
+ // iterations will send us expected/2 + 1 + 1 (per iteration).
+ expected = expected/2 + 1 - 10
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In cogestion avoidance, the packets trains increase by 1 in
+ // each iteration.
+ if i == 0 {
+ // After the first iteration we expect to get the full
+ // congestion window worth of packets in every
+ // iteration.
+ expected += 10
+ }
+ expected++
+ }
+}
+
+func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // Double the number of expected packets for the next iteration.
+ expected *= 2
+ }
+}
+
+func TestCongestionAvoidance(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd/2.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
+ // acknowledgements), then the congestion avoidance part will consume
+ // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
+ // remains in the "ack count" (which will cause cwnd to be incremented
+ // once it reaches cwnd acks).
+ //
+ // So we're straight into congestion avoidance with cwnd set to
+ // expected/2 + 1.
+ //
+ // Check that packets trains of cwnd packets are sent, and that cwnd is
+ // incremented by 1 after we acknowledge each packet.
+ expected = expected/2 + 1
+ for i := 0; i < iterations; i++ {
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+
+ // In cogestion avoidance, the packets trains increase by 1 in
+ // each iteration.
+ expected++
+ }
+}
+
+// cubicCwnd returns an estimate of a cubic window given the
+// originalCwnd, wMax, last congestion event time and sRTT.
+func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
+ cwnd := float64(origCwnd)
+ // We wait 50ms between each iteration so sRTT as computed by cubic
+ // should be close to 50ms.
+ elapsed := (time.Since(congEventTime) + sRTT).Seconds()
+ k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
+ wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
+ cwnd += (wtRTT - cwnd) / cwnd
+ return int(cwnd)
+}
+
+func TestCubicCongestionAvoidance(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ enableCUBIC(t, c)
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in one shot. Packets will only be written at the
+ // MTU size though.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
+ }
+
+ // Don't acknowledge the first packet of the last packet train. Let's
+ // wait for them to time out, which will trigger a restart of slow
+ // start, and initialization of ssthresh to cwnd * 0.7.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ // Acknowledge all pending data.
+ c.SendAck(790, bytesRead)
+
+ // Store away the time we sent the ACK and assuming a 200ms RTO
+ // we estimate that the sender will have an RTO 200ms from now
+ // and go back into slow start.
+ packetDropTime := time.Now().Add(200 * time.Millisecond)
+
+ // This part is tricky: when the timeout happened, we had "expected"
+ // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
+ // By acknowledging "expected" packets, the slow-start part will
+ // increase cwnd to expected/2 essentially putting the connection
+ // straight into congestion avoidance.
+ wMax := expected
+ // Lower expected as per cubic spec after a congestion event.
+ expected = int(float64(expected) * 0.7)
+ cwnd := expected
+ for i := 0; i < iterations; i++ {
+ // Cubic grows window independent of ACKs. Cubic Window growth
+ // is a function of time elapsed since last congestion event.
+ // As a result the congestion window does not grow
+ // deterministically in response to ACKs.
+ //
+ // We need to roughly estimate what the cwnd of the sender is
+ // based on when we sent the dupacks.
+ cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
+
+ packetsExpected := cwnd
+ for j := 0; j < packetsExpected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+ t.Logf("expected packets received, next trying to receive any extra packets that may come")
+
+ // If our estimate was correct there should be no more pending packets.
+ // We attempt to read a packet a few times with a short sleep in between
+ // to ensure that we don't see the sender send any unexpected packets.
+ unexpectedPackets := 0
+ for {
+ gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
+ if !gotPacket {
+ break
+ }
+ bytesRead += maxPayload
+ unexpectedPackets++
+ time.Sleep(1 * time.Millisecond)
+ }
+ if unexpectedPackets != 0 {
+ t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
+ }
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
+
+ // Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead)
+ }
+}
+
+func TestRetransmit(t *testing.T) {
+ maxPayload := 32
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ const iterations = 7
+ data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write all the data in two shots. Packets will only be written at the
+ // MTU size though.
+ half := data[:len(data)/2]
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+ half = data[len(data)/2:]
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Do slow start for a few iterations.
+ expected := tcp.InitialCwnd
+ bytesRead := 0
+ for i := 0; i < iterations; i++ {
+ expected = tcp.InitialCwnd << uint(i)
+ if i > 0 {
+ // Acknowledge all the data received so far if not on
+ // first iteration.
+ c.SendAck(790, bytesRead)
+ }
+
+ // Read all packets expected on this iteration. Don't
+ // acknowledge any of them just yet, so that we can measure the
+ // congestion window.
+ for j := 0; j < expected; j++ {
+ c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+ bytesRead += maxPayload
+ }
+
+ // Check we don't receive any more packets on this iteration.
+ // The timeout can't be too high or we'll trigger a timeout.
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+ }
+
+ // Wait for a timeout and retransmit.
+ rtxOffset := bytesRead - maxPayload*expected
+ c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+ if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
+ t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ }
+
+ // Acknowledge half of the pending data.
+ rtxOffset = bytesRead - expected*maxPayload/2
+ c.SendAck(790, rtxOffset)
+
+ // Receive the remaining data, making sure that acknowledged data is not
+ // retransmitted.
+ for offset := rtxOffset; offset < len(data); offset += maxPayload {
+ c.ReceiveAndCheckPacket(data, offset, maxPayload)
+ c.SendAck(790, offset+maxPayload)
+ }
+
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fe037602b..7d8987219 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -168,8 +168,8 @@ func TestTCPResetsSentIncrement(t *testing.T) {
// Receive the SYN-ACK reply.
b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
ackHeaders := &context.Headers{
SrcPort: context.TestPort,
@@ -269,8 +269,8 @@ func TestConnectResetAfterClose(t *testing.T) {
time.Sleep(3 * time.Second)
for {
b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin {
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
// This is a retransmit of the FIN, ignore it.
continue
}
@@ -553,9 +553,13 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
// We shouldn't consume a sequence number on RST.
checker.SeqNum(uint32(c.IRS)+1),
))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- // This final should be ignored because an ACK on a reset doesn't
- // mean anything.
+ // This final ACK should be ignored because an ACK on a reset doesn't mean
+ // anything.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -618,6 +622,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.SeqNum(uint32(c.IRS)+1),
))
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
// Cause a RST to be generated by closing the read end now since we have
// unread data.
c.EP.Shutdown(tcpip.ShutdownRead)
@@ -630,6 +638,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// We shouldn't consume a sequence number on RST.
checker.SeqNum(uint32(c.IRS)+1),
))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
// The ACK to the FIN should now be rejected since the connection has been
// closed by a RST.
@@ -1510,8 +1522,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
for bytesReceived != dataLen {
b := c.GetPacket()
numPackets++
- tcp := header.TCP(header.IPv4(b).Payload())
- payloadLen := len(tcp.Payload())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ payloadLen := len(tcpHdr.Payload())
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
@@ -1522,7 +1534,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
)
pdata := data[bytesReceived : bytesReceived+payloadLen]
- if p := tcp.Payload(); !bytes.Equal(pdata, p) {
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
t.Fatalf("got data = %v, want = %v", p, pdata)
}
bytesReceived += payloadLen
@@ -1530,7 +1542,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
if c.TimeStampEnabled {
// If timestamp option is enabled, echo back the timestamp and increment
// the TSEcr value included in the packet and send that back as the TSVal.
- parsedOpts := tcp.ParsedOptions()
+ parsedOpts := tcpHdr.ParsedOptions()
tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
options = tsOpt[:]
@@ -1757,8 +1769,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
),
)
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
// Wait for retransmit.
time.Sleep(1 * time.Second)
@@ -1766,8 +1778,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
- checker.SrcPort(tcp.SourcePort()),
- checker.SeqNum(tcp.SequenceNumber()),
+ checker.SrcPort(tcpHdr.SourcePort()),
+ checker.SeqNum(tcpHdr.SequenceNumber()),
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -1775,8 +1787,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// Send SYN-ACK.
iss := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
@@ -2336,491 +2348,6 @@ func TestFinWithPartialAck(t *testing.T) {
})
}
-func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, nil)
-
- const iterations = 7
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // Double the number of expected packets for the next iteration.
- expected *= 2
- }
-}
-
-func TestCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, nil)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd/2.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
- // acknowledgements), then the congestion avoidance part will consume
- // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
- // remains in the "ack count" (which will cause cwnd to be incremented
- // once it reaches cwnd acks).
- //
- // So we're straight into congestion avoidance with cwnd set to
- // expected/2 + 1.
- //
- // Check that packets trains of cwnd packets are sent, and that cwnd is
- // incremented by 1 after we acknowledge each packet.
- expected = expected/2 + 1
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- expected++
- }
-}
-
-// cubicCwnd returns an estimate of a cubic window given the
-// originalCwnd, wMax, last congestion event time and sRTT.
-func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
- cwnd := float64(origCwnd)
- // We wait 50ms between each iteration so sRTT as computed by cubic
- // should be close to 50ms.
- elapsed := (time.Since(congEventTime) + sRTT).Seconds()
- k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
- wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
- cwnd += (wtRTT - cwnd) / cwnd
- return int(cwnd)
-}
-
-func TestCubicCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- enableCUBIC(t, c)
-
- c.CreateConnected(789, 30000, nil)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
-
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd * 0.7.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all pending data.
- c.SendAck(790, bytesRead)
-
- // Store away the time we sent the ACK and assuming a 200ms RTO
- // we estimate that the sender will have an RTO 200ms from now
- // and go back into slow start.
- packetDropTime := time.Now().Add(200 * time.Millisecond)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 essentially putting the connection
- // straight into congestion avoidance.
- wMax := expected
- // Lower expected as per cubic spec after a congestion event.
- expected = int(float64(expected) * 0.7)
- cwnd := expected
- for i := 0; i < iterations; i++ {
- // Cubic grows window independent of ACKs. Cubic Window growth
- // is a function of time elapsed since last congestion event.
- // As a result the congestion window does not grow
- // deterministically in response to ACKs.
- //
- // We need to roughly estimate what the cwnd of the sender is
- // based on when we sent the dupacks.
- cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
-
- packetsExpected := cwnd
- for j := 0; j < packetsExpected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
- t.Logf("expected packets received, next trying to receive any extra packets that may come")
-
- // If our estimate was correct there should be no more pending packets.
- // We attempt to read a packet a few times with a short sleep in between
- // to ensure that we don't see the sender send any unexpected packets.
- unexpectedPackets := 0
- for {
- gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
- if !gotPacket {
- break
- }
- bytesRead += maxPayload
- unexpectedPackets++
- time.Sleep(1 * time.Millisecond)
- }
- if unexpectedPackets != 0 {
- t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
- }
-}
-
-func TestFastRecovery(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, nil)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- for i := 0; i < 3; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
- }
-
- // Now send 7 mode duplicate acks. Each of these should cause a window
- // inflation by 1 and cause the sender to send an extra packet.
- for i := 0; i < 7; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the retransmit due to partial ack.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
- }
-
- // Receive the 10 extra packets that should have been released due to
- // the congestion window inflation in recovery.
- for i := 0; i < 10; i++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // A partial ACK during recovery should reduce congestion window by the
- // number acked. Since we had "expected" packets outstanding before sending
- // partial ack and we acked expected/2 , the cwnd and outstanding should
- // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered
- // fast recovery). Which means the sender should not send any more packets
- // till we ack this one.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(790, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 10
- // packets outstanding.
- //
- // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
- // the same segment that takes us out of recovery. But because of that
- // the actual cwnd at exit of recovery will be expected/2 + 1 as we
- // acked a cwnd worth of packets which will increase the cwnd further by
- // 1 in congestion avoidance.
- //
- // Now in the first iteration since there are 10 packets outstanding.
- // We would expect to get expected/2 +1 - 10 packets. But subsequent
- // iterations will send us expected/2 + 1 + 1 (per iteration).
- expected = expected/2 + 1 - 10
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 10
- }
- expected++
- }
-}
-
-func TestRetransmit(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, nil)
-
- const iterations = 7
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in two shots. Packets will only be written at the
- // MTU size though.
- half := data[:len(data)/2]
- if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
- half = data[len(data)/2:]
- if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Wait for a timeout and retransmit.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
- }
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the remaining data, making sure that acknowledged data is not
- // retransmitted.
- for offset := rtxOffset; offset < len(data); offset += maxPayload {
- c.ReceiveAndCheckPacket(data, offset, maxPayload)
- c.SendAck(790, offset+maxPayload)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-}
-
func TestUpdateListenBacklog(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -3008,8 +2535,8 @@ func TestReceivedSegmentQueuing(t *testing.T) {
checker.TCPFlags(header.TCPFlagAck),
),
)
- tcp := header.TCP(header.IPv4(b).Payload())
- ack := seqnum.Value(tcp.AckNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ ack := seqnum.Value(tcpHdr.AckNumber())
if ack == last {
break
}
@@ -3053,6 +2580,10 @@ func TestReadAfterClosedState(t *testing.T) {
),
)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
// Send some data and acknowledge the FIN.
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3074,9 +2605,15 @@ func TestReadAfterClosedState(t *testing.T) {
),
)
- // Give the stack the chance to transition to closed state.
+ // Give the stack the chance to transition to closed state. Note that since
+ // both the sender and receiver are now closed, we effectively skip the
+ // TIME-WAIT state.
time.Sleep(1 * time.Second)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
// Wait for receive to be notified.
select {
case <-ch:
@@ -3668,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) {
}
}
-func TestSetCongestionControl(t *testing.T) {
+func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
- cc tcp.CongestionControlOption
- mustPass bool
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
}{
- {"reno", true},
- {"cubic", true},
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
}
for _, tc := range testCases {
@@ -3684,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) {
s := c.Stack()
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err)
+ var oldCC tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
}
- var cc tcp.CongestionControlOption
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tc.cc; got != want {
+
+ got, want := cc, oldCC
+ // If SetTransportProtocolOption is expected to succeed
+ // then the returned value for congestion control should
+ // match the one specified in the
+ // SetTransportProtocolOption call above, else it should
+ // be what it was before the call to
+ // SetTransportProtocolOption.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
t.Fatalf("got congestion control: %v, want: %v", got, want)
}
})
}
}
-func TestAvailableCongestionControl(t *testing.T) {
+func TestStackAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcp.AvailableCongestionControlOption
+ var aCC tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
}
}
-func TestSetAvailableCongestionControl(t *testing.T) {
+func TestStackSetAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcp.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.AvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcp.AvailableCongestionControlOption
+ var cc tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ }
+
+ if connected {
+ c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+ t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
}
}
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcp.CongestionControlOption("cubic")
+ opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
@@ -3868,7 +3479,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
RcvWnd: 30000,
})
- // Receive the SYN-ACK reply.
+ // Receive the SYN-ACK reply.w
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss = seqnum.Value(tcp.SequenceNumber())
@@ -3932,12 +3543,18 @@ func TestListenBacklogFull(t *testing.T) {
time.Sleep(50 * time.Millisecond)
- // Now execute one more handshake. This should not be completed and
- // delivered on an Accept() call as the backlog is full at this point.
- irs, iss := executeHandshake(t, c, context.TestPort+uint16(listenBacklog), false /* synCookieInUse */)
+ // Now execute send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 2,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
- time.Sleep(50 * time.Millisecond)
- // Try to accept the connection.
+ // Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
@@ -3969,16 +3586,8 @@ func TestListenBacklogFull(t *testing.T) {
}
}
- // Now craft the ACK again and verify that the connection is now ready
- // to be accepted.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + uint16(listenBacklog),
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
+ // Now a new handshake must succeed.
+ executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept()
if err == tcpip.ErrWouldBlock {
@@ -3994,6 +3603,7 @@ func TestListenBacklogFull(t *testing.T) {
t.Fatalf("Timed out waiting for accept")
}
}
+
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
@@ -4004,13 +3614,7 @@ func TestListenBacklogFull(t *testing.T) {
}
}
-func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 1
-
+func TestListenSynRcvdQueueFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -4029,48 +3633,72 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Test acceptance.
// Start listening.
listenBacklog := 1
- portOffset := uint16(0)
if err := c.EP.Listen(listenBacklog); err != nil {
t.Fatalf("Listen failed: %v", err)
}
- executeHandshake(t, c, context.TestPort+portOffset, false)
- portOffset++
- // Wait for this to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
+ // Send two SYN's the first one should get a SYN-ACK, the
+ // second one should not get any response and is dropped as
+ // the synRcvd count will be equal to backlog.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
- nonCookieIRS, nonCookieISS := executeHandshake(t, c, context.TestPort+portOffset, false)
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
- // Since the backlog is full at this point this connection will not
- // transition out of handshake and ignore the ACK.
- //
- // At this point there should be 1 completed connection in the backlog
- // and one incomplete one pending for a final ACK and hence not ready to be
- // delivered to the endpoint.
- //
- // Now execute one more handshake. This should not be completed and
- // delivered on an Accept() call as the backlog is full at this point
- // and there is already 1 pending endpoint.
+ // Now execute send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
//
- // This one should use a SYN cookie as the synRcvdCount is equal to the
- // SynRcvdCountThreshold.
- time.Sleep(50 * time.Millisecond)
- portOffset++
- irs, iss := executeHandshake(t, c, context.TestPort+portOffset, true)
+ // NOTE: we did not complete the handshake for the previous one so the
+ // accept backlog should be empty and there should be one connection in
+ // synRcvd state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
- time.Sleep(50 * time.Millisecond)
+ // Now complete the previous connection and verify that there is a connection
+ // to accept.
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
- // Verify that there is only one acceptable connection at this point.
+ // Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err = c.EP.Accept()
+ newEP, _, err := c.EP.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept()
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -4080,27 +3708,68 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
- if err != tcpip.ErrWouldBlock {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ pkt := c.GetPacket()
+ tcp = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcp.Payload()) != data {
+ t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
}
+}
- // Now send an ACK for the half completed connection
+func TestListenBacklogFullSynCookieInUse(t *testing.T) {
+ saved := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = saved
+ }()
+ tcp.SynRcvdCountThreshold = 1
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ // Start listening.
+ listenBacklog := 1
+ portOffset := uint16(0)
+ if err := c.EP.Listen(listenBacklog); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ executeHandshake(t, c, context.TestPort+portOffset, false)
+ portOffset++
+ // Wait for this to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + portOffset - 1,
+ SrcPort: context.TestPort,
DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: nonCookieIRS + 1,
- AckNum: nonCookieISS + 1,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
RcvWnd: 30000,
})
+ // The Syn should be dropped as the endpoint's backlog is full.
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+ // Verify that there is only one acceptable connection at this point.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
- // Verify that the connection is now delivered to the backlog.
_, _, err = c.EP.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
@@ -4116,41 +3785,15 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
- // Finally send an ACK for the connection that used a cookie and verify that
- // it's also completed and delivered.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + portOffset,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs,
- AckNum: iss,
- RcvWnd: 30000,
- })
-
- time.Sleep(50 * time.Millisecond)
- newEP, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
+ // Now verify that there are no more connections that can be accepted.
+ _, _, err = c.EP.Accept()
+ if err != tcpip.ErrWouldBlock {
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
+ t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
}
}
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
}
func TestPassiveConnectionAttemptIncrement(t *testing.T) {
@@ -4165,9 +3808,15 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %v", err)
}
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
if err := c.EP.Listen(1); err != nil {
t.Fatalf("Listen failed: %v", err)
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
stats := c.Stack().Stats()
want := stats.TCP.PassiveConnectionOpenings.Value() + 1
@@ -4218,18 +3867,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
}
srcPort := uint16(context.TestPort)
- // Now attempt 3 handshakes, the first two will fill up the accept and the SYN-RCVD
- // queue for the endpoint.
+ // Now attempt a handshakes it will fill up the accept backlog.
executeHandshake(t, c, srcPort, false)
// Give time for the final ACK to be processed as otherwise the next handshake could
// get accepted before the previous one based on goroutine scheduling.
time.Sleep(50 * time.Millisecond)
- irs, iss := executeHandshake(t, c, srcPort+1, false)
-
- // Wait for a short while for the accepted connection to be delivered to
- // the channel before trying to send the 3rd SYN.
- time.Sleep(40 * time.Millisecond)
want := stats.TCP.ListenOverflowSynDrop.Value() + 1
@@ -4267,26 +3910,44 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
t.Fatalf("Timed out waiting for accept")
}
}
+}
- // Now complete the next connection in SYN-RCVD state as it should
- // have dropped the final ACK to the handshake due to accept queue
- // being full.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
+func TestEndpointBindListenAcceptState(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
- // Now check that there is one more acceptable connections.
- _, _, err = c.EP.Accept()
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ aep, _, err := ep.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ aep, _, err = ep.Accept()
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -4295,19 +3956,23 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
t.Fatalf("Timed out waiting for accept")
}
}
+ if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ // Listening endpoint remains in listen state.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- // Try and accept a 3rd one this should fail.
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- ep, _, err = c.EP.Accept()
- if err == nil {
- t.Fatalf("Accept succeeded when it should have failed got: %+v", ep)
- }
-
- case <-time.After(1 * time.Second):
- }
+ ep.Close()
+ // Give worker goroutines time to receive the close notification.
+ time.Sleep(1 * time.Second)
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ // Accepted endpoint remains open when the listen endpoint is closed.
+ if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 6e12413c6..a4d89e24d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -520,32 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window(rcvWnd) and any options if
+// specified.
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- }
-
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if err != tcpip.ErrConnectStarted {
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -557,13 +546,16 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
checker.TCPFlags(header.TCPFlagSyn),
),
)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
c.SendPacket(nil, &Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
@@ -584,15 +576,38 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for connection")
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ c.Port = tcpHdr.SourcePort()
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
- c.Port = tcp.SourcePort()
+ if epRcvBuf != nil {
+ if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ c.t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+ }
+ c.Connect(iss, rcvWnd, options)
}
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
@@ -690,6 +705,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
if err != nil {
c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -719,6 +737,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
}),
),
)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
tcpSeg := header.TCP(header.IPv4(b).Payload())
synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
@@ -782,6 +804,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for connection")
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
// Store the source port in use by the endpoint.
c.Port = tcpSeg.SourcePort()
@@ -821,10 +846,16 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
if err := ep.Listen(10); err != nil {
c.t.Fatalf("Listen failed: %v", err)
}
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
@@ -847,6 +878,10 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
c.t.Fatalf("Timed out waiting for accept")
}
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
return rep
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 3d52a4f31..fa7278286 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1000,3 +1000,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
+
+// State implements socket.Socket.State.
+func (e *endpoint) State() uint32 {
+ // TODO(b/112063468): Translate internal state to values returned by Linux.
+ return 0
+}
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
index 0f155ec74..4ea684659 100644
--- a/pkg/urpc/urpc.go
+++ b/pkg/urpc/urpc.go
@@ -35,7 +35,7 @@ import (
)
// maxFiles determines the maximum file payload.
-const maxFiles = 16
+const maxFiles = 32
// ErrTooManyFiles is returned when too many file descriptors are mapped.
var ErrTooManyFiles = errors.New("too many files")