diff options
Diffstat (limited to 'pkg')
543 files changed, 34373 insertions, 5637 deletions
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD index 32c601a03..f5c08ea06 100644 --- a/pkg/abi/BUILD +++ b/pkg/abi/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "abi", srcs = [ diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ba233b93f..9553f164d 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -1,11 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + # Package linux contains the constants and types needed to interface with a # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix # when the host OS may not be Linux. package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - go_library( name = "linux", srcs = [ @@ -22,6 +23,8 @@ go_library( "exec.go", "fcntl.go", "file.go", + "file_amd64.go", + "file_arm64.go", "fs.go", "futex.go", "inotify.go", @@ -44,6 +47,7 @@ go_library( "sem.go", "shm.go", "signal.go", + "signalfd.go", "socket.go", "splice.go", "tcp.go", @@ -53,6 +57,7 @@ go_library( "uio.go", "utsname.go", "wait.go", + "xattr.go", ], importpath = "gvisor.dev/gvisor/pkg/abi/linux", visibility = ["//visibility:public"], diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index f78315ebf..6663a199c 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -16,15 +16,17 @@ package linux // Commands from linux/fcntl.h. const ( - F_DUPFD = 0x0 - F_GETFD = 0x1 - F_SETFD = 0x2 - F_GETFL = 0x3 - F_SETFL = 0x4 - F_SETLK = 0x6 - F_SETLKW = 0x7 - F_SETOWN = 0x8 - F_GETOWN = 0x9 + F_DUPFD = 0 + F_GETFD = 1 + F_SETFD = 2 + F_GETFL = 3 + F_SETFL = 4 + F_SETLK = 6 + F_SETLKW = 7 + F_SETOWN = 8 + F_GETOWN = 9 + F_SETOWN_EX = 15 + F_GETOWN_EX = 16 F_DUPFD_CLOEXEC = 1024 + 6 F_SETPIPE_SZ = 1024 + 7 F_GETPIPE_SZ = 1024 + 8 @@ -32,9 +34,9 @@ const ( // Commands for F_SETLK. const ( - F_RDLCK = 0x0 - F_WRLCK = 0x1 - F_UNLCK = 0x2 + F_RDLCK = 0 + F_WRLCK = 1 + F_UNLCK = 2 ) // Flags for fcntl. @@ -42,7 +44,7 @@ const ( FD_CLOEXEC = 00000001 ) -// Lock structure for F_SETLK. +// Flock is the lock structure for F_SETLK. type Flock struct { Type int16 Whence int16 @@ -52,3 +54,16 @@ type Flock struct { Pid int32 _ [4]byte } + +// Flags for F_SETOWN_EX and F_GETOWN_EX. +const ( + F_OWNER_TID = 0 + F_OWNER_PID = 1 + F_OWNER_PGRP = 2 +) + +// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +type FOwnerEx struct { + Type int32 + PID int32 +} diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 7d742871a..16791d03e 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -144,9 +144,13 @@ const ( ModeCharacterDevice = S_IFCHR ModeNamedPipe = S_IFIFO - ModeSetUID = 04000 - ModeSetGID = 02000 - ModeSticky = 01000 + S_ISUID = 04000 + S_ISGID = 02000 + S_ISVTX = 01000 + + ModeSetUID = S_ISUID + ModeSetGID = S_ISGID + ModeSticky = S_ISVTX ModeUserAll = 0700 ModeUserRead = 0400 @@ -186,25 +190,6 @@ const ( RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC ) -// Stat represents struct stat. -type Stat struct { - Dev uint64 - Ino uint64 - Nlink uint64 - Mode uint32 - UID uint32 - GID uint32 - _ int32 - Rdev uint64 - Size int64 - Blksize int64 - Blocks int64 - ATime Timespec - MTime Timespec - CTime Timespec - _ [3]int64 -} - // SizeOfStat is the size of a Stat struct. var SizeOfStat = binary.Size(Stat{}) @@ -271,7 +256,7 @@ type Statx struct { } // FileMode represents a mode_t. -type FileMode uint +type FileMode uint16 // Permissions returns just the permission bits. func (m FileMode) Permissions() FileMode { @@ -301,6 +286,29 @@ func (m FileMode) String() string { return strings.Join(s, "|") } +// DirentType maps file types to dirent types appropriate for (struct +// dirent)::d_type. +func (m FileMode) DirentType() uint8 { + switch m.FileType() { + case ModeSocket: + return DT_SOCK + case ModeSymlink: + return DT_LNK + case ModeRegular: + return DT_REG + case ModeBlockDevice: + return DT_BLK + case ModeDirectory: + return DT_DIR + case ModeCharacterDevice: + return DT_CHR + case ModeNamedPipe: + return DT_FIFO + default: + return DT_UNKNOWN + } +} + var modeExtraBits = abi.FlagSet{ { Flag: ModeSetUID, diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go new file mode 100644 index 000000000..74c554be6 --- /dev/null +++ b/pkg/abi/linux/file_amd64.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go new file mode 100644 index 000000000..f16c07589 --- /dev/null +++ b/pkg/abi/linux/file_arm64.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Mode uint32 + Nlink uint32 + UID uint32 + GID uint32 + Rdev uint64 + _ uint64 + Size int64 + Blksize int32 + _ int32 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [2]int32 +} diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index b416e3472..2c652baa2 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -92,3 +92,10 @@ const ( SYNC_FILE_RANGE_WRITE = 2 SYNC_FILE_RANGE_WAIT_AFTER = 4 ) + +// Flag argument to renameat2(2), from include/uapi/linux/fs.h. +const ( + RENAME_NOREPLACE = (1 << 0) // Don't overwrite target. + RENAME_EXCHANGE = (1 << 1) // Exchange src and dst. + RENAME_WHITEOUT = (1 << 2) // Whiteout src. +) diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 152f6b144..3898d2314 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -325,3 +325,9 @@ const ( RTA_SPORT = 28 RTA_DPORT = 29 ) + +// Route flags, from include/uapi/linux/route.h. +const ( + RTF_GATEWAY = 0x2 + RTF_UP = 0x1 +) diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go new file mode 100644 index 000000000..85fad9956 --- /dev/null +++ b/pkg/abi/linux/signalfd.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +const ( + // SFD_NONBLOCK is a signalfd(2) flag. + SFD_NONBLOCK = 00004000 + + // SFD_CLOEXEC is a signalfd(2) flag. + SFD_CLOEXEC = 02000000 +) + +// SignalfdSiginfo is the siginfo encoding for signalfds. +type SignalfdSiginfo struct { + Signo uint32 + Errno int32 + Code int32 + PID uint32 + UID uint32 + FD int32 + TID uint32 + Band uint32 + Overrun uint32 + TrapNo uint32 + Status int32 + Int int32 + Ptr uint64 + UTime uint64 + STime uint64 + Addr uint64 + AddrLSB uint16 + _ [48]uint8 +} diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d5b731390..766ee4014 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -256,6 +256,17 @@ type SockAddrInet6 struct { Scope_id uint32 } +// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +type SockAddrLink struct { + Family uint16 + Protocol uint16 + InterfaceIndex int32 + ARPHardwareType uint16 + PacketType byte + HardwareAddrLen byte + HardwareAddr [8]byte +} + // UnixPathMax is the maximum length of the path in an AF_UNIX socket. // // From uapi/linux/un.h. @@ -278,6 +289,7 @@ type SockAddr interface { func (s *SockAddrInet) implementsSockAddr() {} func (s *SockAddrInet6) implementsSockAddr() {} +func (s *SockAddrLink) implementsSockAddr() {} func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} @@ -410,6 +422,15 @@ type ControlMessageRights []int32 // ControlMessageRights. const SizeOfControlMessageRight = 4 +// SizeOfControlMessageInq is the size of a TCP_INQ control message. +const SizeOfControlMessageInq = 4 + +// SizeOfControlMessageTOS is the size of an IP_TOS control message. +const SizeOfControlMessageTOS = 1 + +// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message. +const SizeOfControlMessageTClass = 4 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go new file mode 100644 index 000000000..a3b6406fa --- /dev/null +++ b/pkg/abi/linux/xattr.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Constants for extended attributes. +const ( + XATTR_NAME_MAX = 255 + XATTR_SIZE_MAX = 65536 + + XATTR_CREATE = 1 + XATTR_REPLACE = 2 + + XATTR_USER_PREFIX = "user." + XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX) +) diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 39d253b98..6bc486b62 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 47ab65346..36beaade9 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -7,6 +8,7 @@ go_library( srcs = [ "atomic_bitops.go", "atomic_bitops_amd64.s", + "atomic_bitops_arm64.s", "atomic_bitops_common.go", ], importpath = "gvisor.dev/gvisor/pkg/atomicbitops", diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomic_bitops.go index 63aa2b7f1..fcc41a9ea 100644 --- a/pkg/atomicbitops/atomic_bitops.go +++ b/pkg/atomicbitops/atomic_bitops.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 // Package atomicbitops provides basic bitwise operations in an atomic way. // The implementation on amd64 leverages the LOCK prefix directly instead of -// relying on the generic cas primitives. +// relying on the generic cas primitives, and the arm64 leverages the LDAXR +// and STLXR pair primitives. // // WARNING: the bitwise ops provided in this package doesn't imply any memory // ordering. Using them to construct locks must employ proper memory barriers. diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomic_bitops_arm64.s new file mode 100644 index 000000000..97f8808c1 --- /dev/null +++ b/pkg/atomicbitops/atomic_bitops_arm64.s @@ -0,0 +1,139 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +#include "textflag.h" + +TEXT ·AndUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + ANDW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·OrUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + ORRW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·XorUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + EORW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·CompareAndSwapUint32(SB),$0-20 + MOVD addr+0(FP), R0 + MOVW old+8(FP), R1 + MOVW new+12(FP), R2 + +again: + LDAXRW (R0), R3 + CMPW R1, R3 + BNE done + STLXRW R2, (R0), R4 + CBNZ R4, again +done: + MOVW R3, prev+16(FP) + RET + +TEXT ·AndUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + AND R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·OrUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + ORR R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·XorUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + EOR R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·CompareAndSwapUint64(SB),$0-32 + MOVD addr+0(FP), R0 + MOVD old+8(FP), R1 + MOVD new+16(FP), R2 + +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE done + STLXR R2, (R0), R4 + CBNZ R4, again +done: + MOVD R3, prev+24(FP) + RET + +TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9 + MOVD addr+0(FP), R0 + +again: + LDAXRW (R0), R1 + CBZ R1, fail + ADDW $1, R1 + STLXRW R1, (R0), R2 + CBNZ R2, again + MOVW $1, R2 + MOVB R2, ret+8(FP) + RET +fail: + MOVB ZR, ret+8(FP) + RET + +TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9 + MOVD addr+0(FP), R0 + +again: + LDAXRW (R0), R1 + SUBSW $1, R1, R1 + BEQ fail + STLXRW R1, (R0), R2 + CBNZ R2, again + MOVW $1, R2 + MOVB R2, ret+8(FP) + RET +fail: + MOVB ZR, ret+8(FP) + RET diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomic_bitops_common.go index b2a943dcb..85163ad62 100644 --- a/pkg/atomicbitops/atomic_bitops_common.go +++ b/pkg/atomicbitops/atomic_bitops_common.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package atomicbitops diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 09d6c2c1f..543fb54bf 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 0c2dde4f8..93b88a29a 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,17 +1,18 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - go_library( name = "bits", srcs = [ "bits.go", "bits32.go", "bits64.go", - "uint64_arch_amd64.go", + "uint64_arch.go", "uint64_arch_amd64_asm.s", + "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], importpath = "gvisor.dev/gvisor/pkg/bits", diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch.go index faccaa61a..9f23eff77 100644 --- a/pkg/bits/uint64_arch_amd64.go +++ b/pkg/bits/uint64_arch.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package bits diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s new file mode 100644 index 000000000..814ba562d --- /dev/null +++ b/pkg/bits/uint64_arch_arm64_asm.s @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +TEXT ·TrailingZeros64(SB),$0-16 + MOVD x+0(FP), R0 + RBIT R0, R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD R0, ret+8(FP) + RET + +TEXT ·MostSignificantOne64(SB),$0-16 + MOVD x+0(FP), R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD $63, R1 + SUBS R0, R1, R0 // ret = 63 - CLZ + BPL end + MOVD $64, R0 // x == 0 +end: + MOVD R0, ret+8(FP) + RET diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go index 7dd2d1480..9dd2098d1 100644 --- a/pkg/bits/uint64_arch_generic.go +++ b/pkg/bits/uint64_arch_generic.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package bits diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index b692aa3b1..fba5643e8 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "bpf", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index cdec96df1..a0b21d4bd 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index 830e19e07..ed111fd2a 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "cpuid", diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 5d61dc2ff..d37047368 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -183,6 +183,33 @@ const ( X86FeatureAVX512VBMI X86FeatureUMIP X86FeaturePKU + X86FeatureOSPKE + X86FeatureWAITPKG + X86FeatureAVX512_VBMI2 + _ // ecx bit 7 is reserved + X86FeatureGFNI + X86FeatureVAES + X86FeatureVPCLMULQDQ + X86FeatureAVX512_VNNI + X86FeatureAVX512_BITALG + X86FeatureTME + X86FeatureAVX512_VPOPCNTDQ + _ // ecx bit 15 is reserved + X86FeatureLA57 + // ecx bits 17-21 are reserved + _ + _ + _ + _ + _ + X86FeatureRDPID + // ecx bits 23-24 are reserved + _ + _ + X86FeatureCLDEMOTE + _ // ecx bit 26 is reserved + X86FeatureMOVDIRI + X86FeatureMOVDIR64B ) // Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX. @@ -353,9 +380,24 @@ var x86FeatureStrings = map[Feature]string{ X86FeatureAVX512VL: "avx512vl", // Block 3. - X86FeatureAVX512VBMI: "avx512vbmi", - X86FeatureUMIP: "umip", - X86FeaturePKU: "pku", + X86FeatureAVX512VBMI: "avx512vbmi", + X86FeatureUMIP: "umip", + X86FeaturePKU: "pku", + X86FeatureOSPKE: "ospke", + X86FeatureWAITPKG: "waitpkg", + X86FeatureAVX512_VBMI2: "avx512_vbmi2", + X86FeatureGFNI: "gfni", + X86FeatureVAES: "vaes", + X86FeatureVPCLMULQDQ: "vpclmulqdq", + X86FeatureAVX512_VNNI: "avx512_vnni", + X86FeatureAVX512_BITALG: "avx512_bitalg", + X86FeatureTME: "tme", + X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq", + X86FeatureLA57: "la57", + X86FeatureRDPID: "rdpid", + X86FeatureCLDEMOTE: "cldemote", + X86FeatureMOVDIRI: "movdiri", + X86FeatureMOVDIR64B: "movdir64b", // Block 4. X86FeatureXSAVEOPT: "xsaveopt", diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9961baaa9..0b4b7cc44 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -24,6 +25,7 @@ go_library( proto_library( name = "eventchannel_proto", srcs = ["event.proto"], + visibility = ["//:sandbox"], ) go_proto_library( diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index 785c685a0..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index e54e7371c..56495cbd9 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index bd1d614b6..e590a71ba 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -18,6 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", + "//pkg/syncutil", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index d59159912..e7c3a3a0b 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error { *ep.dataLen() = w.Len() // Return control to the client. + raceBecomeInactive() if err := ep.futexSwitchToPeer(); err != nil { return err } @@ -112,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error { return nil case epsBlocked | epsShutdown: atomic.AddInt32(&ep.ctrl.state, -epsBlocked) - return shutdownError{} + return ShutdownError{} default: // Most likely due to ep.enterFutexWait() being called concurrently // from multiple goroutines. diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 991018684..3cdb576e1 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() { // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), // ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer -// Endpoint, to unblock and return errors. It does not wait for concurrent -// calls to return. Successive calls to Shutdown have no effect. +// Endpoint, to unblock and return ShutdownErrors. It does not wait for +// concurrent calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool { return atomic.LoadUint32(&ep.shutdown) != 0 } -type shutdownError struct{} +// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown() +// has been called. +type ShutdownError struct{} // Error implements error.Error. -func (shutdownError) Error() string { +func (ShutdownError) Error() string { return "flipcall connection shutdown" } @@ -180,7 +182,11 @@ const ( // Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), // ep.SendRecv(), and ep.SendLast() have never been called. func (ep *Endpoint) Connect() error { - return ep.ctrlConnect() + err := ep.ctrlConnect() + if err == nil { + raceBecomeActive() + } + return err } // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then @@ -192,6 +198,7 @@ func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -218,9 +225,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then // they can only shoot themselves in the foot. *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlRoundTrip(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -240,6 +249,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) } *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlWakeLast(); err != nil { return err } diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 435e4eeae..168a487ec 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -62,6 +62,9 @@ func (c *testConnection) destroy() { } func testSendRecv(t *testing.T, c *testConnection) { + // This shared variable is used to confirm that synchronization between + // flipcall endpoints is visible to the Go race detector. + state := 0 var serverRun sync.WaitGroup serverRun.Add(1) go func() { @@ -71,11 +74,19 @@ func testSendRecv(t *testing.T, c *testConnection) { t.Errorf("server Endpoint.RecvFirst() failed: %v", err) return } + state++ + if state != 2 { + t.Errorf("shared state counter: got %d, wanted 2", state) + } t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") if _, err := c.serverEP.SendRecv(0); err != nil { t.Errorf("server Endpoint.SendRecv() failed: %v", err) return } + state++ + if state != 4 { + t.Errorf("shared state counter: got %d, wanted 4", state) + } t.Logf("server Endpoint got packet 3") }() defer func() { @@ -89,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } + state++ + if state != 1 { + t.Errorf("shared state counter: got %d, wanted 1", state) + } t.Logf("client Endpoint sending packet 1 and waiting for packet 2") if _, err := c.clientEP.SendRecv(0); err != nil { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } + state++ + if state != 3 { + t.Errorf("shared state counter: got %d, wanted 3", state) + } t.Logf("client Endpoint got packet 2, sending packet 3") if err := c.clientEP.SendLast(0); err != nil { t.Fatalf("client Endpoint.SendLast() failed: %v", err) diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 73e6eef29..27b8939fc 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -17,6 +17,8 @@ package flipcall import ( "reflect" "unsafe" + + "gvisor.dev/gvisor/pkg/syncutil" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte { bsReflect.Cap = int(ep.dataCap) return bs } + +// ioSync is a dummy variable used to indicate synchronization to the Go race +// detector. Compare syscall.ioSync. +var ioSync int64 + +func raceBecomeActive() { + if syncutil.RaceEnabled { + syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) + } +} + +func raceBecomeInactive() { + if syncutil.RaceEnabled { + syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + } +} diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index b127a2bbb..168c1ccff 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error { if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { switch cs := atomic.LoadUint32(ep.connState()); cs { case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } @@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error { return nil case ep.inactiveState: if ep.isShutdownLocally() { - return shutdownError{} + return ShutdownError{} } if err := ep.futexWaitConnState(ep.inactiveState); err != nil { return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) } continue case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index 11716af81..0c5f50397 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index e6a8dbd02..4b9321711 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 8f3defa25..34d2673ef 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,5 +1,6 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index c8e923a74..a5d980d14 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 12615240c..fc5f5779b 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 3b8a691f4..dd6ca6d39 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,5 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -21,6 +23,12 @@ proto_library( visibility = ["//:sandbox"], ) +cc_proto_library( + name = "metric_cc_proto", + visibility = ["//:sandbox"], + deps = [":metric_proto"], +) + go_proto_library( name = "metric_go_proto", importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index c6737bf97..f32244c69 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], @@ -19,11 +20,14 @@ go_library( "pool.go", "server.go", "transport.go", + "transport_flipcall.go", "version.go", ], importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", + "//pkg/fdchannel", + "//pkg/flipcall", "//pkg/log", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 7dc20aeef..221516c6c 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -20,6 +20,8 @@ import ( "sync" "syscall" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -77,6 +79,47 @@ type Client struct { // fidPool is the collection of available fids. fidPool pool + // messageSize is the maximum total size of a message. + messageSize uint32 + + // payloadSize is the maximum payload size of a read or write. + // + // For large reads and writes this means that the read or write is + // broken up into buffer-size/payloadSize requests. + payloadSize uint32 + + // version is the agreed upon version X of 9P2000.L.Google.X. + // version 0 implies 9P2000.L. + version uint32 + + // closedWg is marked as done when the Client.watch() goroutine, which is + // responsible for closing channels and the socket fd, returns. + closedWg sync.WaitGroup + + // sendRecv is the transport function. + // + // This is determined dynamically based on whether or not the server + // supports flipcall channels (preferred as it is faster and more + // efficient, and does not require tags). + sendRecv func(message, message) error + + // -- below corresponds to sendRecvChannel -- + + // channelsMu protects channels. + channelsMu sync.Mutex + + // channelsWg counts the number of channels for which channel.active == + // true. + channelsWg sync.WaitGroup + + // channels is the set of all initialized channels. + channels []*channel + + // availableChannels is a FIFO of inactive channels. + availableChannels []*channel + + // -- below corresponds to sendRecvLegacy -- + // pending is the set of pending messages. pending map[Tag]*response pendingMu sync.Mutex @@ -89,25 +132,12 @@ type Client struct { // Whoever writes to this channel is permitted to call recv. When // finished calling recv, this channel should be emptied. recvr chan bool - - // messageSize is the maximum total size of a message. - messageSize uint32 - - // payloadSize is the maximum payload size of a read or write - // request. For large reads and writes this means that the - // read or write is broken up into buffer-size/payloadSize - // requests. - payloadSize uint32 - - // version is the agreed upon version X of 9P2000.L.Google.X. - // version 0 implies 9P2000.L. - version uint32 } // NewClient creates a new client. It performs a Tversion exchange with // the server to assert that messageSize is ok to use. // -// You should not use the same socket for multiple clients. +// If NewClient succeeds, ownership of socket is transferred to the new Client. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. if messageSize <= msgRegistry.largestFixedSize { @@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client return nil, ErrBadVersionString } for { + // Always exchange the version using the legacy version of the + // protocol. If the protocol supports flipcall, then we switch + // our sendRecv function to use that functionality. Otherwise, + // we stick to sendRecvLegacy. rversion := Rversion{} - err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion) + err := c.sendRecvLegacy(&Tversion{ + Version: versionString(requested), + MSize: messageSize, + }, &rversion) // The server told us to try again with a lower version. if err == syscall.EAGAIN { @@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.version = version break } + + // Can we switch to use the more advanced channels and create + // independent channels for communication? Prefer it if possible. + if versionSupportsFlipcall(c.version) { + // Attempt to initialize IPC-based communication. + for i := 0; i < channelsPerClient; i++ { + if err := c.openChannel(i); err != nil { + log.Warningf("error opening flipcall channel: %v", err) + break // Stop. + } + } + if len(c.channels) >= 1 { + // At least one channel created. + c.sendRecv = c.sendRecvChannel + } else { + // Channel setup failed; fallback. + c.sendRecv = c.sendRecvLegacy + } + } else { + // No channels available: use the legacy mechanism. + c.sendRecv = c.sendRecvLegacy + } + + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + return c, nil } +// watch watches the given socket and releases resources on hangup events. +// +// This is intended to be called as a goroutine. +func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + + events := []unix.PollFd{ + unix.PollFd{ + Fd: int32(socket.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == syscall.EINTR || err == syscall.EAGAIN { + continue + } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() +} + +// openChannel attempts to open a client channel. +// +// Note that this function returns naked errors which should not be propagated +// directly to a caller. It is expected that the errors will be logged and a +// fallback path will be used instead. +func (c *Client) openChannel(id int) error { + var ( + rchannel0 Rchannel + rchannel1 Rchannel + res = new(channel) + ) + + // Open the data channel. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 0, + }, &rchannel0); err != nil { + return fmt.Errorf("error handling Tchannel message: %v", err) + } + if rchannel0.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on primary channel") + } + + // We don't need to hold this. + defer rchannel0.FilePayload().Close() + + // Open the channel for file descriptors. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 1, + }, &rchannel1); err != nil { + return err + } + if rchannel1.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on file descriptor channel") + } + + // Construct the endpoints. + res.desc = flipcall.PacketWindowDescriptor{ + FD: rchannel0.FilePayload().FD(), + Offset: int64(rchannel0.Offset), + Length: int(rchannel0.Length), + } + if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { + rchannel1.FilePayload().Close() + return err + } + + // The fds channel owns the control payload, and it will be closed when + // the channel object is closed. + res.fds.Init(rchannel1.FilePayload().Release()) + + // Save the channel. + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + c.channels = append(c.channels, res) + c.availableChannels = append(c.availableChannels, res) + return nil +} + // handleOne handles a single incoming message. // // This should only be called with the token from recvr. Note that the received @@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error { } } -// sendRecv performs a roundtrip message exchange. +// sendRecvLegacy performs a roundtrip message exchange. // // This is called by internal functions. -func (c *Client) sendRecv(t message, r message) error { +func (c *Client) sendRecvLegacy(t message, r message) error { tag, ok := c.tagPool.Get() if !ok { return ErrOutOfTags @@ -296,12 +479,77 @@ func (c *Client) sendRecv(t message, r message) error { return nil } +// sendRecvChannel uses channels to send a message. +func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. + c.channelsMu.Lock() + if len(c.availableChannels) == 0 { + c.channelsMu.Unlock() + return c.sendRecvLegacy(t, r) + } + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + ch.active = true + c.channelsWg.Add(1) + c.channelsMu.Unlock() + + // Ensure that it's connected. + if !ch.connected { + ch.connected = true + if err := ch.data.Connect(); err != nil { + // The channel is unusable, so don't return it to + // c.availableChannels. However, we still have to mark it as + // inactive so c.watch() doesn't wait for it. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + // Map all transport errors to EIO, but ensure that the real error + // is logged. + log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) + return syscall.EIO + } + } + + // Send the request and receive the server's response. + rsz, err := ch.send(t) + if err != nil { + // See above. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) + return syscall.EIO + } + + // Parse the server's response. + _, retErr := ch.recv(r, rsz) + + // Release the channel. + c.channelsMu.Lock() + ch.active = false + // If c.availableChannels is nil, c.watch() has fired and we should not + // mark this channel as available. + if c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.channelsMu.Unlock() + c.channelsWg.Done() + + return retErr +} + // Version returns the negotiated 9P2000.L.Google version number. func (c *Client) Version() uint32 { return c.version } -// Close closes the underlying socket. -func (c *Client) Close() error { - return c.socket.Close() +// Close closes the underlying socket and channels. +func (c *Client) Close() { + // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already + // been called (by c.watch()). + c.socket.Shutdown() + c.closedWg.Wait() } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index a6cc0617e..de9357389 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -17,7 +17,6 @@ package p9 import ( "fmt" "io" - "runtime" "sync/atomic" "syscall" @@ -45,15 +44,10 @@ func (c *Client) Attach(name string) (File, error) { // newFile returns a new client file. func (c *Client) newFile(fid FID) *clientFile { - cf := &clientFile{ + return &clientFile{ client: c, fid: fid, } - - // Make sure the file is closed. - runtime.SetFinalizer(cf, (*clientFile).Close) - - return cf } // clientFile is provided to clients. @@ -192,7 +186,6 @@ func (c *clientFile) Remove() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the remove message. if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil { @@ -214,7 +207,6 @@ func (c *clientFile) Close() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the close message. if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil { diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index 87b2dd61e..29a0afadf 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -35,23 +35,23 @@ func TestVersion(t *testing.T) { go s.Handle(serverSocket) // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString()) + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) if err != nil { t.Fatalf("got %v, expected nil", err) } // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN { + if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN { t.Errorf("got %v expected %v", err, syscall.EAGAIN) } @@ -60,3 +60,45 @@ func TestVersion(t *testing.T) { t.Errorf("got %v expected %v", err, syscall.EINVAL) } } + +func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + // See above. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + b.Fatalf("socketpair got err %v expected nil", err) + } + defer clientSocket.Close() + + // See above. + s := NewServer(nil) + go s.Handle(serverSocket) + + // See above. + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) + if err != nil { + b.Fatalf("got %v, expected nil", err) + } + + // Initialize messages. + sendRecv := fn(c) + tversion := &Tversion{ + Version: versionString(highestSupportedVersion), + MSize: DefaultMessageSize, + } + rversion := new(Rversion) + + // Run in a loop. + for i := 0; i < b.N; i++ { + if err := sendRecv(tversion, rversion); err != nil { + b.Fatalf("got unexpected err: %v", err) + } + } +} + +func BenchmarkSendRecvLegacy(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) +} + +func BenchmarkSendRecvChannel(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) +} diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 907445e15..96d1f2a8e 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -116,7 +116,7 @@ type File interface { // N.B. The server must resolve any lazy paths when open is called. // After this point, read and write may be called on files with no // deletion check, so resolving in the data path is not viable. - Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // Read reads from this file. Open must be called first. // diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 999b4f684..b9582c07f 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -257,7 +257,6 @@ func CanOpen(mode FileMode) bool { // handle implements handler.handle. func (t *Tlopen) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -272,15 +271,15 @@ func (t *Tlopen) handle(cs *connState) message { return newErr(syscall.EINVAL) } - // Are flags valid? - flags := t.Flags &^ OpenFlagsIgnoreMask - if flags&^OpenFlagsModeMask != 0 { - return newErr(syscall.EINVAL) - } - - // Is this an attempt to open a directory as writable? Don't accept. - if ref.mode.IsDir() && flags != ReadOnly { - return newErr(syscall.EINVAL) + if ref.mode.IsDir() { + // Directory must be opened ReadOnly. + if t.Flags&OpenFlagsModeMask != ReadOnly { + return newErr(syscall.EISDIR) + } + // Directory not truncatable. + if t.Flags&OpenTruncate != 0 { + return newErr(syscall.EISDIR) + } } var ( @@ -294,7 +293,6 @@ func (t *Tlopen) handle(cs *connState) message { return syscall.EINVAL } - // Do the open. osFile, qid, ioUnit, err = ref.file.Open(t.Flags) return err }); err != nil { @@ -305,16 +303,16 @@ func (t *Tlopen) handle(cs *connState) message { ref.opened = true ref.openFlags = t.Flags - return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile} + rlopen := &Rlopen{QID: qid, IoUnit: ioUnit} + rlopen.SetFilePayload(osFile) + return rlopen } func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return nil, syscall.EBADF @@ -364,7 +362,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { // Replace the FID reference. cs.InsertFID(t.FID, newRef) - return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil + rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}} + rlcreate.SetFilePayload(osFile) + return rlcreate, nil } // handle implements handler.handle. @@ -386,12 +386,10 @@ func (t *Tsymlink) handle(cs *connState) message { } func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -422,19 +420,16 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { // handle implements handler.handle. func (t *Tlink) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the other FID. refTarget, ok := cs.LookupFID(t.Target) if !ok { return newErr(syscall.EBADF) @@ -463,7 +458,6 @@ func (t *Tlink) handle(cs *connState) message { // handle implements handler.handle. func (t *Trenameat) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.OldName); err != nil { return newErr(err) } @@ -471,14 +465,12 @@ func (t *Trenameat) handle(cs *connState) message { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.OldDirectory) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the other FID. refTarget, ok := cs.LookupFID(t.NewDirectory) if !ok { return newErr(syscall.EBADF) @@ -519,12 +511,10 @@ func (t *Trenameat) handle(cs *connState) message { // handle implements handler.handle. func (t *Tunlinkat) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -573,19 +563,16 @@ func (t *Tunlinkat) handle(cs *connState) message { // handle implements handler.handle. func (t *Trename) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the target. refTarget, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -637,7 +624,6 @@ func (t *Trename) handle(cs *connState) message { // handle implements handler.handle. func (t *Treadlink) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -665,7 +651,6 @@ func (t *Treadlink) handle(cs *connState) message { // handle implements handler.handle. func (t *Tread) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -704,7 +689,6 @@ func (t *Tread) handle(cs *connState) message { // handle implements handler.handle. func (t *Twrite) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -743,12 +727,10 @@ func (t *Tmknod) handle(cs *connState) message { } func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -787,12 +769,10 @@ func (t *Tmkdir) handle(cs *connState) message { } func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -823,7 +803,6 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { // handle implements handler.handle. func (t *Tgetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -852,7 +831,6 @@ func (t *Tgetattr) handle(cs *connState) message { // handle implements handler.handle. func (t *Tsetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -879,7 +857,6 @@ func (t *Tsetattr) handle(cs *connState) message { // handle implements handler.handle. func (t *Tallocate) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -913,7 +890,6 @@ func (t *Tallocate) handle(cs *connState) message { // handle implements handler.handle. func (t *Txattrwalk) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -926,7 +902,6 @@ func (t *Txattrwalk) handle(cs *connState) message { // handle implements handler.handle. func (t *Txattrcreate) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -939,7 +914,6 @@ func (t *Txattrcreate) handle(cs *connState) message { // handle implements handler.handle. func (t *Treaddir) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -973,7 +947,6 @@ func (t *Treaddir) handle(cs *connState) message { // handle implements handler.handle. func (t *Tfsync) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -997,7 +970,6 @@ func (t *Tfsync) handle(cs *connState) message { // handle implements handler.handle. func (t *Tstatfs) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1188,7 +1160,6 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI // handle implements handler.handle. func (t *Twalk) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1209,7 +1180,6 @@ func (t *Twalk) handle(cs *connState) message { // handle implements handler.handle. func (t *Twalkgetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1266,7 +1236,6 @@ func (t *Tumknod) handle(cs *connState) message { // handle implements handler.handle. func (t *Tlconnect) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1287,5 +1256,47 @@ func (t *Tlconnect) handle(cs *connState) message { return newErr(err) } - return &Rlconnect{File: osFile} + rlconnect := &Rlconnect{} + rlconnect.SetFilePayload(osFile) + return rlconnect +} + +// handle implements handler.handle. +func (t *Tchannel) handle(cs *connState) message { + // Ensure that channels are enabled. + if err := cs.initializeChannels(); err != nil { + return newErr(err) + } + + ch := cs.lookupChannel(t.ID) + if ch == nil { + return newErr(syscall.ENOSYS) + } + + // Return the payload. Note that we need to duplicate the file + // descriptor for the channel allocator, because sending is a + // destructive operation between sendRecvLegacy (and now the newer + // channel send operations). Same goes for the client FD. + rchannel := &Rchannel{ + Offset: uint64(ch.desc.Offset), + Length: uint64(ch.desc.Length), + } + switch t.Control { + case 0: + // Open the main data channel. + mfd, err := syscall.Dup(int(cs.channelAlloc.FD())) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(mfd)) + case 1: + cfd, err := syscall.Dup(ch.client.FD()) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(cfd)) + default: + return newErr(syscall.EINVAL) + } + return rchannel } diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index fd9eb1c5d..ffdd7e8c6 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -64,6 +64,21 @@ type filer interface { SetFilePayload(*fd.FD) } +// filePayload embeds a File object. +type filePayload struct { + File *fd.FD +} + +// FilePayload returns the file payload. +func (f *filePayload) FilePayload() *fd.FD { + return f.File +} + +// SetFilePayload sets the received file. +func (f *filePayload) SetFilePayload(file *fd.FD) { + f.File = file +} + // Tversion is a version request. type Tversion struct { // MSize is the message size to use. @@ -524,10 +539,7 @@ type Rlopen struct { // IoUnit is the recommended I/O unit. IoUnit uint32 - // File may be attached via the socket. - // - // This is an extension specific to this package. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType { return MsgRlopen } -// FilePayload returns the file payload. -func (r *Rlopen) FilePayload() *fd.FD { - return r.File -} - -// SetFilePayload sets the received file. -func (r *Rlopen) SetFilePayload(file *fd.FD) { - r.File = file -} - // String implements fmt.Stringer. func (r *Rlopen) String() string { return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File) @@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string { // Rlconnect is a connect response. type Rlconnect struct { - // File is a host socket. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType { return MsgRlconnect } -// FilePayload returns the file payload. -func (r *Rlconnect) FilePayload() *fd.FD { - return r.File +// String implements fmt.Stringer. +func (r *Rlconnect) String() string { + return fmt.Sprintf("Rlconnect{File: %v}", r.File) } -// SetFilePayload sets the received file. -func (r *Rlconnect) SetFilePayload(file *fd.FD) { - r.File = file +// Tchannel creates a new channel. +type Tchannel struct { + // ID is the channel ID. + ID uint32 + + // Control is 0 if the Rchannel response should provide the flipcall + // component of the channel, and 1 if the Rchannel response should + // provide the fdchannel component of the channel. + Control uint32 +} + +// Decode implements encoder.Decode. +func (t *Tchannel) Decode(b *buffer) { + t.ID = b.Read32() + t.Control = b.Read32() +} + +// Encode implements encoder.Encode. +func (t *Tchannel) Encode(b *buffer) { + b.Write32(t.ID) + b.Write32(t.Control) +} + +// Type implements message.Type. +func (*Tchannel) Type() MsgType { + return MsgTchannel } // String implements fmt.Stringer. -func (r *Rlconnect) String() string { - return fmt.Sprintf("Rlconnect{File: %v}", r.File) +func (t *Tchannel) String() string { + return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control) +} + +// Rchannel is the channel response. +type Rchannel struct { + Offset uint64 + Length uint64 + filePayload +} + +// Decode implements encoder.Decode. +func (r *Rchannel) Decode(b *buffer) { + r.Offset = b.Read64() + r.Length = b.Read64() +} + +// Encode implements encoder.Encode. +func (r *Rchannel) Encode(b *buffer) { + b.Write64(r.Offset) + b.Write64(r.Length) +} + +// Type implements message.Type. +func (*Rchannel) Type() MsgType { + return MsgRchannel +} + +// String implements fmt.Stringer. +func (r *Rchannel) String() string { + return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } const maxCacheSize = 3 @@ -2356,4 +2409,6 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) + msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index e12831dbd..d3090535a 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -32,21 +32,21 @@ import ( type OpenFlags uint32 const ( - // ReadOnly is a Topen and Tcreate flag indicating read-only mode. + // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode. ReadOnly OpenFlags = 0 - // WriteOnly is a Topen and Tcreate flag indicating write-only mode. + // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode. WriteOnly OpenFlags = 1 - // ReadWrite is a Topen flag indicates read-write mode. + // ReadWrite is a Tlopen flag indicates read-write mode. ReadWrite OpenFlags = 2 // OpenFlagsModeMask is a mask of valid OpenFlags mode bits. OpenFlagsModeMask OpenFlags = 3 - // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen. - // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h. - OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000 + // OpenTruncate is a Tlopen flag indicating that the opened file should be + // truncated. + OpenTruncate OpenFlags = 01000 ) // ConnectFlags is the mode passed to Connect operations. @@ -71,25 +71,32 @@ const ( // OSFlags converts a p9.OpenFlags to an int compatible with open(2). func (o OpenFlags) OSFlags() int { - return int(o & OpenFlagsModeMask) + // "flags contains Linux open(2) flags bits" - 9P2000.L + return int(o) } // String implements fmt.Stringer. func (o OpenFlags) String() string { - switch o { + var buf strings.Builder + switch mode := o & OpenFlagsModeMask; mode { case ReadOnly: - return "ReadOnly" + buf.WriteString("ReadOnly") case WriteOnly: - return "WriteOnly" + buf.WriteString("WriteOnly") case ReadWrite: - return "ReadWrite" - case OpenFlagsModeMask: - return "OpenFlagsModeMask" - case OpenFlagsIgnoreMask: - return "OpenFlagsIgnoreMask" + buf.WriteString("ReadWrite") default: - return "UNDEFINED" + fmt.Fprintf(&buf, "%#o", mode) } + otherFlags := o &^ OpenFlagsModeMask + if otherFlags&OpenTruncate != 0 { + buf.WriteString("|OpenTruncate") + otherFlags &^= OpenTruncate + } + if otherFlags != 0 { + fmt.Fprintf(&buf, "|%#o", otherFlags) + } + return buf.String() } // Tag is a message tag. @@ -378,6 +385,8 @@ const ( MsgRlconnect = 137 MsgTallocate = 138 MsgRallocate = 139 + MsgTchannel = 250 + MsgRchannel = 251 ) // QIDType represents the file type for QIDs. @@ -812,7 +821,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) { attr.Mode = FileMode(s.Mode) } if req.NLink { - attr.NLink = s.Nlink + attr.NLink = uint64(s.Nlink) } if req.UID { attr.UID = UID(s.Uid) diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 6e939a49a..28707c0ca 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -77,7 +77,7 @@ go_library( go_test( name = "client_test", - size = "small", + size = "medium", srcs = ["client_test.go"], embed = [":p9test"], deps = [ diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index fe649c2e8..6e758148d 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -1044,11 +1044,11 @@ func TestReaddir(t *testing.T) { if _, err := f.Readdir(0, 1); err != syscall.EINVAL { t.Errorf("readdir got %v, wanted EINVAL", err) } - if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) + if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) } - if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) + if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) } backend.EXPECT().Open(p9.ReadOnly).Times(1) if _, _, _, err := f.Open(p9.ReadOnly); err != nil { @@ -1065,75 +1065,93 @@ func TestReaddir(t *testing.T) { func TestOpen(t *testing.T) { type openTest struct { name string - mode p9.OpenFlags + flags p9.OpenFlags err error match func(p9.FileMode) bool } cases := []openTest{ { - name: "invalid", - mode: ^p9.OpenFlagsModeMask, - err: syscall.EINVAL, - match: func(p9.FileMode) bool { return true }, - }, - { name: "not-openable-read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "not-openable-write-only", - mode: p9.WriteOnly, + flags: p9.WriteOnly, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "not-openable-read-write", - mode: p9.ReadWrite, + flags: p9.ReadWrite, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "directory-read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: nil, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "directory-read-write", - mode: p9.ReadWrite, - err: syscall.EINVAL, + flags: p9.ReadWrite, + err: syscall.EISDIR, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "directory-write-only", - mode: p9.WriteOnly, - err: syscall.EINVAL, + flags: p9.WriteOnly, + err: syscall.EISDIR, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, }, { name: "write-only", - mode: p9.WriteOnly, + flags: p9.WriteOnly, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, }, { name: "read-write", - mode: p9.ReadWrite, + flags: p9.ReadWrite, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "directory-read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: syscall.EISDIR, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "write-only-truncate", + flags: p9.WriteOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "read-write-truncate", + flags: p9.ReadWrite | p9.OpenTruncate, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, }, } - // Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + // Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice // - returning a file works as expected for name := range newTypeMap(nil) { @@ -1171,25 +1189,25 @@ func TestOpen(t *testing.T) { // Attempt the given open. if tc.err != nil { // We expect an error, just test and return. - if _, _, _, err := f.Open(tc.mode); err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) + if _, _, _, err := f.Open(tc.flags); err != tc.err { + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) } return } // Run an FD test, since we expect success. fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Open(tc.mode).Return(send, p9.QID{}, uint32(0), nil).Times(1) - recv, _, _, err := f.Open(tc.mode) + backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1) + recv, _, _, err := f.Open(tc.flags) if err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) } return recv }) // If the open was successful, attempt another one. - if _, _, _, err := f.Open(tc.mode); err != syscall.EINVAL { - t.Errorf("second open with mode %v got %v, want EINVAL", tc.mode, err) + if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL { + t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err) } // Ensure that all illegal operations fail. @@ -2127,3 +2145,98 @@ func TestConcurrency(t *testing.T) { } } } + +func TestReadWriteConcurrent(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const ( + instances = 10 + iterations = 10000 + dataSize = 1024 + ) + var ( + dataSets [instances][dataSize]byte + backends [instances]*Mock + files [instances]p9.File + ) + + // Walk to the file normally. + for i := 0; i < instances; i++ { + _, backends[i], files[i] = walkHelper(h, "file", root) + defer files[i].Close() + } + + // Open the files. + for i := 0; i < instances; i++ { + backends[i].EXPECT().Open(p9.ReadWrite) + if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + } + + // Initialize random data for each instance. + for i := 0; i < instances; i++ { + if _, err := rand.Read(dataSets[i][:]); err != nil { + t.Fatalf("error initializing dataSet#%d, got %v", i, err) + } + } + + // Define our random read/write mechanism. + randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { + // Prepare the backend. + backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if n := copy(p, data); n != len(data) { + // Note that we have to assert the result here, as the Return statement + // below cannot be dynamic: it will be bound before this call is made. + h.t.Errorf("wanted length %d, got %d", len(data), n) + } + }).Return(len(data), nil) + + // Execute the read. + if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) + return // No sense doing check below. + } + if !bytes.Equal(test, data) { + t.Errorf("data integrity failed during read") // Not as expected. + } + } + randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { + // Prepare the backend. + backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if !bytes.Equal(p, data) { + h.t.Errorf("data integrity failed during write") // Not as expected. + } + }).Return(len(data), nil) + + // Execute the write. + if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) + } + } + randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { + test := make([]byte, len(data)) + for i := 0; i < n; i++ { + if rand.Intn(2) == 0 { + randRead(h, backend, f, data, test) + } else { + randWrite(h, backend, f, data) + } + } + } + + // Start reading and writing. + var wg sync.WaitGroup + for i := 0; i < instances; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) + }(i) + } + wg.Wait() +} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 95846e5f7..4d3271b37 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator { // Finish completes all checks and shuts down the server. func (h *Harness) Finish() { - h.clientSocket.Close() + h.clientSocket.Shutdown() h.wg.Wait() h.mockCtrl.Finish() } @@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) { }() // Create the client. - client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString()) + client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) if err != nil { serverSocket.Close() clientSocket.Close() diff --git a/pkg/p9/server.go b/pkg/p9/server.go index b294efbb0..40b8fa023 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -21,6 +21,9 @@ import ( "sync/atomic" "syscall" + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -45,7 +48,6 @@ type Server struct { } // NewServer returns a new server. -// func NewServer(attacher Attacher) *Server { return &Server{ attacher: attacher, @@ -85,6 +87,8 @@ type connState struct { // version 0 implies 9P2000.L. version uint32 + // -- below relates to the legacy handler -- + // recvOkay indicates that a receive may start. recvOkay chan bool @@ -93,6 +97,20 @@ type connState struct { // sendDone is signalled when a send is finished. sendDone chan error + + // -- below relates to the flipcall handler -- + + // channelMu protects below. + channelMu sync.Mutex + + // channelWg represents active workers. + channelWg sync.WaitGroup + + // channelAlloc allocates channel memory. + channelAlloc *flipcall.PacketWindowAllocator + + // channels are the set of initialized channels. + channels []*channel } // fidRef wraps a node and tracks references. @@ -386,6 +404,105 @@ func (cs *connState) WaitTag(t Tag) { <-ch } +// initializeChannels initializes all channels. +// +// This is a no-op if channels are already initialized. +func (cs *connState) initializeChannels() (err error) { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + + // Initialize our channel allocator. + if cs.channelAlloc == nil { + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return err + } + cs.channelAlloc = alloc + } + + // Create all the channels. + for len(cs.channels) < channelsPerClient { + res := &channel{ + done: make(chan struct{}), + } + + res.desc, err = cs.channelAlloc.Allocate(channelSize) + if err != nil { + return err + } + if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil { + return err + } + + socks, err := fdchannel.NewConnectedSockets() + if err != nil { + res.data.Destroy() // Cleanup. + return err + } + res.fds.Init(socks[0]) + res.client = fd.New(socks[1]) + + cs.channels = append(cs.channels, res) + + // Start servicing the channel. + // + // When we call stop, we will close all the channels and these + // routines should finish. We need the wait group to ensure + // that active handlers are actually finished before cleanup. + cs.channelWg.Add(1) + go func() { // S/R-SAFE: Server side. + defer cs.channelWg.Done() + if err := res.service(cs); err != nil { + // Don't log flipcall.ShutdownErrors, which we expect to be + // returned during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("p9.channel.service: %v", err) + } + } + }() + } + + return nil +} + +// lookupChannel looks up the channel with given id. +// +// The function returns nil if no such channel is available. +func (cs *connState) lookupChannel(id uint32) *channel { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + if id >= uint32(len(cs.channels)) { + return nil + } + return cs.channels[id] +} + +// handle handles a single message. +func (cs *connState) handle(m message) (r message) { + defer func() { + if r == nil { + // Don't allow a panic to propagate. + recover() + + // Include a useful log message. + log.Warningf("panic in handler: %s", debug.Stack()) + + // Wrap in an EFAULT error; we don't really have a + // better way to describe this kind of error. It will + // usually manifest as a result of the test framework. + r = newErr(syscall.EFAULT) + } + }() + if handler, ok := m.(handler); ok { + // Call the message handler. + r = handler.handle(cs) + } else { + // Produce an ENOSYS error. + r = newErr(syscall.ENOSYS) + } + return +} + // handleRequest handles a single request. // // The recvDone channel is signaled when recv is done (with a error if @@ -428,41 +545,20 @@ func (cs *connState) handleRequest() { } // Handle the message. - var r message // r is the response. - defer func() { - if r == nil { - // Don't allow a panic to propagate. - recover() - - // Include a useful log message. - log.Warningf("panic in handler: %s", debug.Stack()) + r := cs.handle(m) - // Wrap in an EFAULT error; we don't really have a - // better way to describe this kind of error. It will - // usually manifest as a result of the test framework. - r = newErr(syscall.EFAULT) - } + // Clear the tag before sending. That's because as soon as this hits + // the wire, the client can legally send the same tag. + cs.ClearTag(tag) - // Clear the tag before sending. That's because as soon as this - // hits the wire, the client can legally send another message - // with the same tag. - cs.ClearTag(tag) + // Send back the result. + cs.sendMu.Lock() + err = send(cs.conn, tag, r) + cs.sendMu.Unlock() + cs.sendDone <- err - // Send back the result. - cs.sendMu.Lock() - err = send(cs.conn, tag, r) - cs.sendMu.Unlock() - cs.sendDone <- err - }() - if handler, ok := m.(handler); ok { - // Call the message handler. - r = handler.handle(cs) - } else { - // Produce an ENOSYS error. - r = newErr(syscall.ENOSYS) - } + // Return the message to the cache. msgRegistry.put(m) - m = nil // 'm' should not be touched after this point. } func (cs *connState) handleRequests() { @@ -477,7 +573,27 @@ func (cs *connState) stop() { close(cs.recvDone) close(cs.sendDone) - for _, fidRef := range cs.fids { + // Free the channels. + cs.channelMu.Lock() + for _, ch := range cs.channels { + ch.Shutdown() + } + cs.channelWg.Wait() + for _, ch := range cs.channels { + ch.Close() + } + cs.channels = nil // Clear. + cs.channelMu.Unlock() + + // Free the channel memory. + if cs.channelAlloc != nil { + cs.channelAlloc.Destroy() + } + + // Close all remaining fids. + for fid, fidRef := range cs.fids { + delete(cs.fids, fid) + // Drop final reference in the FID table. Note this should // always close the file, since we've ensured that there are no // handlers running via the wait for Pending => 0 below. @@ -510,7 +626,7 @@ func (cs *connState) service() error { for i := 0; i < pending; i++ { <-cs.sendDone } - return err + return nil } // This handler is now pending. diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 5648df589..6e8b4bbcd 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -54,7 +54,10 @@ const ( headerLength uint32 = 7 // maximumLength is the largest possible message. - maximumLength uint32 = 4 * 1024 * 1024 + maximumLength uint32 = 1 << 20 + + // DefaultMessageSize is a sensible default. + DefaultMessageSize uint32 = 64 << 10 // initialBufferLength is the initial data buffer we allocate. initialBufferLength uint32 = 64 diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go new file mode 100644 index 000000000..233f825e3 --- /dev/null +++ b/pkg/p9/transport_flipcall.go @@ -0,0 +1,243 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "runtime" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +// channelsPerClient is the number of channels to create per client. +// +// While the client and server will generally agree on this number, in reality +// it's completely up to the server. We simply define a minimum of 2, and a +// maximum of 4, and select the number of available processes as a tie-breaker. +// Note that we don't want the number of channels to be too large, because each +// will account for channelSize memory used, which can be large. +var channelsPerClient = func() int { + n := runtime.NumCPU() + if n < 2 { + return 2 + } + if n > 4 { + return 4 + } + return n +}() + +// channelSize is the channel size to create. +// +// We simply ensure that this is larger than the largest possible message size, +// plus the flipcall packet header, plus the two bytes we write below. +const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength) + +// channel is a fast IPC channel. +// +// The same object is used by both the server and client implementations. In +// general, the client will use only the send and recv methods. +type channel struct { + desc flipcall.PacketWindowDescriptor + data flipcall.Endpoint + fds fdchannel.Endpoint + buf buffer + + // -- client only -- + connected bool + active bool + + // -- server only -- + client *fd.FD + done chan struct{} +} + +// reset resets the channel buffer. +func (ch *channel) reset(sz uint32) { + ch.buf.data = ch.data.Data()[:sz] +} + +// service services the channel. +func (ch *channel) service(cs *connState) error { + rsz, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rsz > 0 { + m, err := ch.recv(nil, rsz) + if err != nil { + return err + } + r := cs.handle(m) + msgRegistry.put(m) + rsz, err = ch.send(r) + if err != nil { + return err + } + } + return nil // Done. +} + +// Shutdown shuts down the channel. +// +// This must be called before Close. +func (ch *channel) Shutdown() { + ch.data.Shutdown() +} + +// Close closes the channel. +// +// This must only be called once, and cannot return an error. Note that +// synchronization for this method is provided at a high-level, depending on +// whether it is the client or server. This cannot be called while there are +// active callers in either service or sendRecv. +// +// Precondition: the channel should be shutdown. +func (ch *channel) Close() error { + // Close all backing transports. + ch.fds.Destroy() + ch.data.Destroy() + if ch.client != nil { + ch.client.Close() + } + return nil +} + +// send sends the given message. +// +// The return value is the size of the received response. Not that in the +// server case, this is the size of the next request. +func (ch *channel) send(m message) (uint32, error) { + if log.IsLogging(log.Debug) { + log.Debugf("send [channel @%p] %s", ch, m.String()) + } + + // Send any file payload. + sentFD := false + if filer, ok := m.(filer); ok { + if f := filer.FilePayload(); f != nil { + if err := ch.fds.SendFD(f.FD()); err != nil { + return 0, err + } + f.Close() // Per sendRecvLegacy. + sentFD = true // To mark below. + } + } + + // Encode the message. + // + // Note that IPC itself encodes the length of messages, so we don't + // need to encode a standard 9P header. We write only the message type. + ch.reset(0) + + ch.buf.WriteMsgType(m.Type()) + if sentFD { + ch.buf.Write8(1) // Incoming FD. + } else { + ch.buf.Write8(0) // No incoming FD. + } + m.Encode(&ch.buf) + ssz := uint32(len(ch.buf.data)) // Updated below. + + // Is there a payload? + if payloader, ok := m.(payloader); ok { + p := payloader.Payload() + copy(ch.data.Data()[ssz:], p) + ssz += uint32(len(p)) + } + + // Perform the one-shot communication. + return ch.data.SendRecv(ssz) +} + +// recv decodes a message that exists on the channel. +// +// If the passed r is non-nil, then the type must match or an error will be +// generated. If the passed r is nil, then a new message will be created and +// returned. +func (ch *channel) recv(r message, rsz uint32) (message, error) { + // Decode the response from the inline buffer. + ch.reset(rsz) + t := ch.buf.ReadMsgType() + hasFD := ch.buf.Read8() != 0 + if t == MsgRlerror { + // Change the message type. We check for this special case + // after decoding below, and transform into an error. + r = &Rlerror{} + } else if r == nil { + nr, err := msgRegistry.get(0, t) + if err != nil { + return nil, err + } + r = nr // New message. + } else if t != r.Type() { + // Not an error and not the expected response; propagate. + return nil, &ErrBadResponse{Got: t, Want: r.Type()} + } + + // Is there a payload? Copy from the latter portion. + if payloader, ok := r.(payloader); ok { + fs := payloader.FixedSize() + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } + ch.buf.data = ch.buf.data[:fs] + } + + r.Decode(&ch.buf) + if ch.buf.isOverrun() { + // Nothing valid was available. + log.Debugf("recv [got %d bytes, needed more]", rsz) + return nil, ErrNoValidMessage + } + + // Read any FD result. + if hasFD { + if rfd, err := ch.fds.RecvFDNonblock(); err == nil { + f := fd.New(rfd) + if filer, ok := r.(filer); ok { + // Set the payload. + filer.SetFilePayload(f) + } else { + // Don't want the FD. + f.Close() + } + } else { + // The header bit was set but nothing came in. + log.Warningf("expected FD, got err: %v", err) + } + } + + // Log a message. + if log.IsLogging(log.Debug) { + log.Debugf("recv [channel @%p] %s", ch, r.String()) + } + + // Convert errors appropriately; see above. + if rlerr, ok := r.(*Rlerror); ok { + return nil, syscall.Errno(rlerr.Error) + } + + return r, nil +} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index cdb3bc841..2f50ff3ea 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) { t.Fatalf("unable to create file: %v", err) } - if err := send(client, Tag(1), &Rlopen{File: f}); err != nil { + rlopen := &Rlopen{} + rlopen.SetFilePayload(f) + if err := send(client, Tag(1), rlopen); err != nil { t.Fatalf("send got err %v expected nil", err) } diff --git a/pkg/p9/version.go b/pkg/p9/version.go index c2a2885ae..36a694c58 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 7 + highestSupportedVersion uint32 = 9 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -148,3 +148,16 @@ func VersionSupportsMultiUser(v uint32) bool { func versionSupportsTallocate(v uint32) bool { return v >= 7 } + +// versionSupportsFlipcall returns true if version v supports IPC channels from +// the flipcall package. Note that these must be negotiated, but this version +// string indicates that such a facility exists. +func versionSupportsFlipcall(v uint32) bool { + return v >= 8 +} + +// VersionSupportsOpenTruncateFlag returns true if version v supports +// passing the OpenTruncate flag to Tlopen. +func VersionSupportsOpenTruncateFlag(v uint32) bool { + return v >= 9 +} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 697e7a2f4..078f084b2 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 30ec8e6e2..38cea9be3 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index e340d9f98..4f4b70fef 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 9c08452fc..7ad59dfd7 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "weak_ref_list", diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index d1024e49d..af94e944d 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") package(licenses = ["notice"]) diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index c7503f2cc..fc36efa23 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -199,6 +199,10 @@ func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string { return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx) } +func ruleLabel(ruleSetIdx int, sysno uintptr, idx int, name string) string { + return fmt.Sprintf("rule_%v_%v_%v_%v", ruleSetIdx, sysno, idx, name) +} + func checkArgsLabel(sysno uintptr) string { return fmt.Sprintf("checkArgs_%v", sysno) } @@ -223,6 +227,19 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) labelled = true + case GreaterThan: + labelGood := fmt.Sprintf("gt%v", i) + high, low := uint32(a>>32), uint32(a) + // assert arg_high < high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + // arg_high > high + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + // arg_low < low + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true default: return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) } diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index 29eec8db1..84c841d7f 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -49,6 +49,9 @@ func (a AllowAny) String() (s string) { // AllowValue specifies a value that needs to be strictly matched. type AllowValue uintptr +// GreaterThan specifies a value that needs to be strictly smaller. +type GreaterThan uintptr + func (a AllowValue) String() (s string) { return fmt.Sprintf("%#x ", uintptr(a)) } diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 353686ed3..abbee7051 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -340,6 +340,54 @@ func TestBasic(t *testing.T) { }, }, }, + { + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + GreaterThan(0xf), + GreaterThan(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + specs: []spec{ + { + desc: "GreaterThan: Syscall argument allowed", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "GreaterThan: Syscall argument disallowed (equal)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "Syscall argument disallowed (smaller)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "GreaterThan2: Syscall argument allowed", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "GreaterThan2: Syscall argument disallowed (equal)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "GreaterThan2: Syscall argument disallowed (smaller)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, } { instrs, err := BuildProgram(test.ruleSets, test.defaultAction) if err != nil { diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index 48413f1fb..da6b9eaaf 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -38,7 +38,7 @@ func main() { syscall.SYS_CLONE: {}, syscall.SYS_CLOSE: {}, syscall.SYS_DUP: {}, - syscall.SYS_DUP2: {}, + syscall.SYS_DUP3: {}, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_WAIT: {}, diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go index 0a3d92854..be328db12 100644 --- a/pkg/seccomp/seccomp_unsafe.go +++ b/pkg/seccomp/seccomp_unsafe.go @@ -35,7 +35,7 @@ type sockFprog struct { //go:nosplit func SetFilter(instrs []linux.BPFInstruction) syscall.Errno { // PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 { return errno } diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index f38fb39f3..22abdc69f 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD index 700385907..1b487b887 100644 --- a/pkg/segment/BUILD +++ b/pkg/segment/BUILD @@ -1,10 +1,10 @@ +load("//tools/go_generics:defs.bzl", "go_template") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) -load("//tools/go_generics:defs.bzl", "go_template") - go_template( name = "generic_range", srcs = ["range.go"], diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index 694486296..a27c35e21 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,12 +1,12 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package( default_visibility = ["//visibility:private"], licenses = ["notice"], ) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "int_range", out = "int_range.go", diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 53989301f..2d6379c86 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -8,5 +8,7 @@ package_group( packages = [ "//pkg/sentry/...", "//runsc/...", + # Code generated by go_marshal relies on go_marshal libraries. + "//tools/go_marshal/...", ], ) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 7aace2d7b..18c73cc24 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,9 +1,9 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "arch", srcs = [ @@ -42,6 +42,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "registers_cc_proto", + visibility = ["//visibility:public"], + deps = [":registers_proto"], +) + go_proto_library( name = "registers_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 9e7db8b30..67daa6c24 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -305,7 +305,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs()) return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return c.Native(0), nil } @@ -320,6 +320,6 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { _, err := c.PtraceSetRegs(bytes.NewBuffer(buf)) return err } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return nil } diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go index dfd62cbdb..23e009ef3 100644 --- a/pkg/sentry/context/context.go +++ b/pkg/sentry/context/context.go @@ -12,10 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package context defines the sentry's Context type. +// Package context defines an internal context type. +// +// The given Context conforms to the standard Go context, but mandates +// additional methods that are specific to the kernel internals. Note however, +// that the Context described by this package carries additional constraints +// regarding concurrent access and retaining beyond the scope of a call. +// +// See the Context type for complete details. package context import ( + "context" + "time" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { type Context interface { log.Logger amutex.Sleeper + context.Context // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate @@ -72,19 +83,36 @@ type Context interface { // AddressSpace is activated. Normally activate is the same value as the // deactivate parameter passed to UninterruptibleSleepStart. UninterruptibleSleepFinish(activate bool) +} + +// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep +// methods for anonymous embedding in other types that do not implement sleeps. +type NoopSleeper struct { + amutex.NoopSleeper +} + +// UninterruptibleSleepStart does nothing. +func (NoopSleeper) UninterruptibleSleepStart(bool) {} + +// UninterruptibleSleepFinish does nothing. +func (NoopSleeper) UninterruptibleSleepFinish(bool) {} + +// Deadline returns zero values, meaning no deadline. +func (NoopSleeper) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns nil. +func (NoopSleeper) Done() <-chan struct{} { + return nil +} - // Value returns the value associated with this Context for key, or nil if - // no value is associated with key. Successive calls to Value with the same - // key returns the same result. - // - // A key identifies a specific value in a Context. Functions that wish to - // retrieve values from Context typically allocate a key in a global - // variable then use that key as the argument to Context.Value. A key can - // be any type that supports equality; packages should define keys as an - // unexported type to avoid collisions. - Value(key interface{}) interface{} +// Err returns nil. +func (NoopSleeper) Err() error { + return nil } +// logContext implements basic logging. type logContext struct { log.Logger NoopSleeper @@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} { return nil } -// NoopSleeper is a noop implementation of amutex.Sleeper and -// Context.UninterruptibleSleep* methods for anonymous embedding in other types -// that do not want to notify kernel.Task about sleeps. -type NoopSleeper struct { - amutex.NoopSleeper -} - -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} - -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} - // bgContext is the context returned by context.Background. var bgContext = &logContext{Logger: log.Log()} diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD index 3b6841b7e..581e7aa96 100644 --- a/pkg/sentry/context/contexttest/BUILD +++ b/pkg/sentry/context/contexttest/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "contexttest", testonly = 1, diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index bf802d1b6..5522cecd0 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 1f78d54a2..e1f2fea60 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -22,6 +22,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/urpc" ) @@ -56,6 +57,9 @@ type Profile struct { // traceFile is the current execution trace output file. traceFile *fd.FD + + // Kernel is the kernel under profile. + Kernel *kernel.Kernel } // StartCPUProfile is an RPC stub which starts recording the CPU profile in a @@ -147,6 +151,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { return err } + // Ensure all trace contexts are registered. + p.Kernel.RebuildTraceContexts() + p.traceFile = output return nil } @@ -158,9 +165,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error { defer p.mu.Unlock() if p.traceFile == nil { - return errors.New("Execution tracing not start") + return errors.New("Execution tracing not started") } + // Similarly to the case above, if tasks have not ended traces, we will + // lose information. Thus we need to rebuild the tasks in order to have + // complete information. This will not lose information if multiple + // traces are overlapping. + p.Kernel.RebuildTraceContexts() + trace.Stop() p.traceFile.Close() p.traceFile = nil diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index c35faeb4c..ced51c66c 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -268,14 +268,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error { } // Process contains information about a single process in a Sandbox. -// TODO(b/117881927): Implement TTY field. type Process struct { UID auth.KUID `json:"uid"` PID kernel.ThreadID `json:"pid"` // Parent PID - PPID kernel.ThreadID `json:"ppid"` + PPID kernel.ThreadID `json:"ppid"` + Threads []kernel.ThreadID `json:"threads"` // Processor utilization C int32 `json:"c"` + // TTY name of the process. Will be of the form "pts/N" if there is a + // TTY, or "?" if there is not. + TTY string `json:"tty"` // Start time STime string `json:"stime"` // CPU time @@ -285,18 +288,19 @@ type Process struct { } // ProcessListToTable prints a table with the following format: -// UID PID PPID C STIME TIME CMD -// 0 1 0 0 14:04 505262ns tail +// UID PID PPID C TTY STIME TIME CMD +// 0 1 0 0 pty/4 14:04 505262ns tail func ProcessListToTable(pl []*Process) string { var buf bytes.Buffer tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0) - fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD") + fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD") for _, d := range pl { - fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s", + fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s", d.UID, d.PID, d.PPID, d.C, + d.TTY, d.STime, d.Time, d.Cmd) @@ -307,7 +311,7 @@ func ProcessListToTable(pl []*Process) string { // ProcessListToJSON will return the JSON representation of ps. func ProcessListToJSON(pl []*Process) (string, error) { - b, err := json.Marshal(pl) + b, err := json.MarshalIndent(pl, "", " ") if err != nil { return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err) } @@ -334,7 +338,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ts := k.TaskSet() now := k.RealtimeClock().Now() for _, tg := range ts.Root.ThreadGroups() { - pid := tg.PIDNamespace().IDOfThreadGroup(tg) + pidns := tg.PIDNamespace() + pid := pidns.IDOfThreadGroup(tg) + // If tg has already been reaped ignore it. if pid == 0 { continue @@ -345,16 +351,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ppid := kernel.ThreadID(0) if p := tg.Leader().Parent(); p != nil { - ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup()) + ppid = pidns.IDOfThreadGroup(p.ThreadGroup()) } + threads := tg.MemberIDs(pidns) *out = append(*out, &Process{ - UID: tg.Leader().Credentials().EffectiveKUID, - PID: pid, - PPID: ppid, - STime: formatStartTime(now, tg.Leader().StartTime()), - C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), - Time: tg.CPUStats().SysTime.String(), - Cmd: tg.Leader().Name(), + UID: tg.Leader().Credentials().EffectiveKUID, + PID: pid, + PPID: ppid, + Threads: threads, + STime: formatStartTime(now, tg.Leader().StartTime()), + C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), + Time: tg.CPUStats().SysTime.String(), + Cmd: tg.Leader().Name(), + TTY: ttyName(tg.TTY()), }) } sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID }) @@ -395,3 +404,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 { } return int32(percentCPU) } + +func ttyName(tty *kernel.TTY) string { + if tty == nil { + return "?" + } + return fmt.Sprintf("pts/%d", tty.Index) +} diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go index d8ada2694..0a88459b2 100644 --- a/pkg/sentry/control/proc_test.go +++ b/pkg/sentry/control/proc_test.go @@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) { }{ { pl: []*Process{}, - expected: "UID PID PPID C STIME TIME CMD", + expected: "UID PID PPID C TTY STIME TIME CMD", }, { pl: []*Process{ @@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) { PID: 0, PPID: 0, C: 0, + TTY: "?", STime: "0", Time: "0", Cmd: "zero", @@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) { PID: 1, PPID: 1, C: 1, + TTY: "pts/4", STime: "1", Time: "1", Cmd: "one", }, }, - expected: `UID PID PPID C STIME TIME CMD -0 0 0 0 0 0 zero -1 1 1 1 1 1 one`, + expected: `UID PID PPID C TTY STIME TIME CMD +0 0 0 0 ? 0 0 zero +1 1 1 1 pts/4 1 1 one`, }, } diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 7e8918722..1098ed777 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "device", diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index d7259b47b..c035ffff7 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_library( name = "fs", @@ -67,9 +68,9 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/syncutil", "//pkg/syserror", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 80e106e6f..a0d9e8496 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "dev", srcs = [ diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index fbca06761..3cb73bd78 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { // Remove removes the given file or symlink. The root dirent is used to // resolve name, and must not be nil. -func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { +func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error { // Check the root. if root == nil { panic("Dirent.Remove: root must not be nil") @@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { // Remove cannot remove directories. if IsDir(child.Inode.StableAttr) { return syscall.EISDIR + } else if dirPath { + return syscall.ENOTDIR } // Remove cannot remove a mount point. diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 884e3ff06..47bc72a88 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) { } d := f.Dirent - if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil { + if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil { t.Fatalf("root.Remove(root, %q) failed: %v", name, err) } diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index bf00b9c09..277ee4c31 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "fdpipe", diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index 8e4d839e1..577445148 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index bb8117f89..c0a6e884b 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -515,6 +515,11 @@ type lockedReader struct { // File is the file to read from. File *File + + // Offset is the offset to start at. + // + // This applies only to Read, not ReadAt. + Offset int64 } // Read implements io.Reader.Read. @@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { return 0, syserror.ErrInterrupted } - n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset) + n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) + r.Offset += n return int(n), err } @@ -544,11 +550,21 @@ type lockedWriter struct { // File is the file to write to. File *File + + // Offset is the offset to start at. + // + // This applies only to Write, not WriteAt. + Offset int64 } // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - return w.WriteAt(buf, w.File.offset) + if w.Ctx.Interrupted() { + return 0, syserror.ErrInterrupted + } + n, err := w.WriteAt(buf, w.Offset) + w.Offset += int64(n) + return int(n), err } // WriteAt implements io.Writer.WriteAt. @@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // io.Copy, since our own Write interface does not have this same // contract. Enforce that here. for written < len(buf) { + if w.Ctx.Interrupted() { + return written, syserror.ErrInterrupted + } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) if n > 0 { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index d86f5bf45..b88303f17 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -15,6 +15,8 @@ package fs import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -105,8 +107,11 @@ type FileOperations interface { // on the destination, following by a buffered copy with standard Read // and Write operations. // + // If dup is set, the data should be duplicated into the destination + // and retained. + // // The same preconditions as Read apply. - WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error) + WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error) // Write writes src to file at offset and returns the number of bytes // written which must be greater than or equal to 0. Like Read, file @@ -126,7 +131,7 @@ type FileOperations interface { // source. See WriteTo for details regarding how this is called. // // The same preconditions as Write apply; FileFlags.Write must be set. - ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error) + ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error) // Fsync writes buffered modifications of file and/or flushes in-flight // operations to backing storage based on syncType. The range to sync is diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 9820f0b13..225e40186 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "gvisor.dev/gvisor/pkg/refs" @@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme } // WriteTo implements FileOperations.WriteTo. -func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) { err = f.onTop(ctx, file, func(file *File, ops FileOperations) error { - n, err = ops.WriteTo(ctx, file, dst, opts) + n, err = ops.WriteTo(ctx, file, dst, count, dup) return err // Will overwrite itself. }) return @@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm } // ReadFrom implements FileOperations.ReadFrom. -func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) { // See above; f.upper must be non-nil. - return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts) + return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count) } // Fsync implements FileOperations.Fsync. diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD index a9d6d9301..358dc2be3 100644 --- a/pkg/sentry/fs/filetest/BUILD +++ b/pkg/sentry/fs/filetest/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "filetest", testonly = 1, diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go index 0fab876a9..4338ae1fa 100644 --- a/pkg/sentry/fs/flags.go +++ b/pkg/sentry/fs/flags.go @@ -64,6 +64,10 @@ type FileFlags struct { // NonSeekable indicates that file.offset isn't used. NonSeekable bool + + // Truncate indicates that the file should be truncated before opened. + // This is only applicable if the file is regular. + Truncate bool } // SettableFileFlags is a subset of FileFlags above that can be changed @@ -118,6 +122,9 @@ func (f FileFlags) ToLinux() (mask uint) { if f.LargeFile { mask |= linux.O_LARGEFILE } + if f.Truncate { + mask |= linux.O_TRUNC + } switch { case f.Read && f.Write: diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6499f87ac..b2e8d9c77 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "dirty_set_impl", diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 626b9126a..fc5b3b1a1 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -15,6 +15,8 @@ package fsutil import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. -func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. -func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e239f12a5..b06a71cc2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -209,3 +209,29 @@ func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) { } delete(f.mappings, chunkStart) } + +// RegenerateMappings must be called when the file description mapped by f +// changes, to replace existing mappings of the previous file description. +func (f *HostFileMapper) RegenerateMappings(fd int) error { + f.mapsMu.Lock() + defer f.mapsMu.Unlock() + + for chunkStart, m := range f.mappings { + prot := syscall.PROT_READ + if m.writable { + prot |= syscall.PROT_WRITE + } + _, _, errno := syscall.Syscall6( + syscall.SYS_MMAP, + m.addr, + chunkSize, + uintptr(prot), + syscall.MAP_SHARED|syscall.MAP_FIXED, + uintptr(fd), + uintptr(chunkStart)) + if errno != 0 { + return errno + } + } + return nil +} diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index d2495cb83..30475f340 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -100,13 +100,30 @@ func (h *HostMappable) Translate(ctx context.Context, required, optional memmap. } // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -func (h *HostMappable) InvalidateUnsavable(ctx context.Context) error { +func (h *HostMappable) InvalidateUnsavable(_ context.Context) error { h.mu.Lock() h.mappings.InvalidateAll(memmap.InvalidateOpts{}) h.mu.Unlock() return nil } +// NotifyChangeFD must be called after the file description represented by +// CachedFileObject.FD() changes. +func (h *HostMappable) NotifyChangeFD() error { + // Update existing sentry mappings to refer to the new file description. + if err := h.hostFileMapper.RegenerateMappings(h.backingFile.FD()); err != nil { + return err + } + + // Shoot down existing application mappings of the old file description; + // they will be remapped with the new file description on demand. + h.mu.Lock() + defer h.mu.Unlock() + + h.mappings.InvalidateAll(memmap.InvalidateOpts{}) + return nil +} + // MapInternal implements platform.File.MapInternal. func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) @@ -144,7 +161,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error { mask := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: newSize} - if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil { + if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil { return err } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index d404a79d4..798920d18 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -140,12 +140,16 @@ type CachedFileObject interface { // WriteFromBlocksAt may return a partial write without an error. WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) - // SetMaskedAttributes sets the attributes in attr that are true in mask - // on the backing file. + // SetMaskedAttributes sets the attributes in attr that are true in + // mask on the backing file. If the mask contains only ATime or MTime + // and the CachedFileObject has an FD to the file, then this operation + // is a noop unless forceSetTimestamps is true. This avoids an extra + // RPC to the gofer in the open-read/write-close case, when the + // timestamps on the file will be updated by the host kernel for us. // // SetMaskedAttributes may be called at any point, regardless of whether // the file was opened. - SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error + SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error // Allocate allows the caller to reserve disk space for the inode. // It's equivalent to fallocate(2) with 'mode=0'. @@ -224,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Perms: true} - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil { return false } c.attr.Perms = perms @@ -246,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, UID: owner.UID.Ok(), GID: owner.GID.Ok(), } - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil { return err } if owner.UID.Ok() { @@ -282,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In AccessTime: !ts.ATimeOmit, ModificationTime: !ts.MTimeOmit, } - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil { + // Call SetMaskedAttributes with forceSetTimestamps = true to make sure + // the timestamp is updated. + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil { return err } if !ts.ATimeOmit { @@ -305,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: size} - if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil { c.dataMu.Unlock() return err } @@ -394,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) c.dirtyAttr.Size = false // Write out cached attributes. - if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil { c.attrMu.Unlock() return err } @@ -953,6 +959,23 @@ func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error return nil } +// NotifyChangeFD must be called after the file description represented by +// CachedFileObject.FD() changes. +func (c *CachingInodeOperations) NotifyChangeFD() error { + // Update existing sentry mappings to refer to the new file description. + if err := c.hostFileMapper.RegenerateMappings(c.backingFile.FD()); err != nil { + return err + } + + // Shoot down existing application mappings of the old file description; + // they will be remapped with the new file description on demand. + c.mapsMu.Lock() + defer c.mapsMu.Unlock() + + c.mappings.InvalidateAll(memmap.InvalidateOpts{}) + return nil +} + // Evict implements pgalloc.EvictableMemoryUser.Evict. func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.EvictableRange) { c.mapsMu.Lock() @@ -1021,7 +1044,6 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { } c.refs.MergeAdjacent(fr) c.dataMu.Unlock() - } // MapInternal implements platform.File.MapInternal. This is used when we diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go index eb5730c35..129f314c8 100644 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ b/pkg/sentry/fs/fsutil/inode_cached_test.go @@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block return srcs.NumBytes(), nil } -func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { +func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { return nil } @@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B return w.WriteFromBlocks(srcs) } -func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { +func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { return nil } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 6b993928c..4a005c605 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "gofer", diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 9e2e412cd..7960b9c7b 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -214,28 +214,64 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr return entries, nil } +// maybeSync will call FSync on the file if either the cache policy or file +// flags require it. +func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error { + if n == 0 { + // Nothing to sync. + return nil + } + + if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) { + // Call WriteOut directly, as some "writethrough" filesystems + // do not support sync. + return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode) + } + + flags := file.Flags() + var syncType fs.SyncType + switch { + case flags.Direct || flags.Sync: + syncType = fs.SyncAll + case flags.DSync: + syncType = fs.SyncData + default: + // No need to sync. + return nil + } + + return f.Fsync(ctx, file, offset, offset+n, syncType) +} + // Write implements fs.FileOperations.Write. func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { if fs.IsDir(file.Dirent.Inode.StableAttr) { // Not all remote file systems enforce this so this client does. return 0, syserror.EISDIR } - cp := f.inodeOperations.session().cachePolicy - if cp.useCachingInodeOps(file.Dirent.Inode) { - n, err := f.inodeOperations.cachingInodeOps.Write(ctx, src, offset) - if err != nil { - return n, err - } - if cp.writeThrough(file.Dirent.Inode) { - // Write out the file. - err = f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode) - } - return n, err + + var ( + n int64 + err error + ) + // The write is handled in different ways depending on the cache policy + // and availability of a host-mappable FD. + if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) { + n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset) + } else if f.inodeOperations.fileState.hostMappable != nil { + n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset) + } else { + n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) } - if f.inodeOperations.fileState.hostMappable != nil { - return f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset) + + // We may need to sync the written bytes. + if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil { + // Sync failed. Report 0 bytes written, since none of them are + // guaranteed to have been synced. + return 0, syncErr } - return src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) + + return n, err } // incrementReadCounters increments the read counters for the read starting at the given time. We @@ -273,7 +309,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO } // Fsync implements fs.FileOperations.Fsync. -func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error { +func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start, end int64, syncType fs.SyncType) error { switch syncType { case fs.SyncAll, fs.SyncData: if err := file.Dirent.Inode.WriteOut(ctx); err != nil { diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index 9aa68a70e..bb8312849 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -28,8 +28,14 @@ func (f *fileOperations) afterLoad() { // Manually load the open handles. var err error + + // The file may have been opened with Truncate, but we don't + // want to re-open it with Truncate or we will lose data. + flags := f.flags + flags.Truncate = false + // TODO(b/38173783): Context is not plumbed to save/restore. - f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags) + f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps) if err != nil { return fmt.Errorf("failed to re-open handle: %v", err) } diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go index 8f8ab5d29..cf96dd9fa 100644 --- a/pkg/sentry/fs/gofer/fs.go +++ b/pkg/sentry/fs/gofer/fs.go @@ -58,6 +58,11 @@ const ( // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to // true. limitHostFDTranslationKey = "limit_host_fd_translation" + + // overlayfsStaleRead if present closes cached readonly file after the first + // write. This is done to workaround a limitation of overlayfs in kernels + // before 4.19 where open FDs are not updated after the file is copied up. + overlayfsStaleRead = "overlayfs_stale_read" ) // defaultAname is the default attach name. @@ -145,6 +150,7 @@ type opts struct { version string privateunixsocket bool limitHostFDTranslation bool + overlayfsStaleRead bool } // options parses mount(2) data into structured options. @@ -247,6 +253,11 @@ func options(data string) (opts, error) { delete(options, limitHostFDTranslationKey) } + if _, ok := options[overlayfsStaleRead]; ok { + o.overlayfsStaleRead = true + delete(options, overlayfsStaleRead) + } + // Fail to attach if the caller wanted us to do something that we // don't support. if len(options) > 0 { diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index 27eeae3d9..b86c49b39 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -39,14 +39,22 @@ type handles struct { // Host is an *fd.FD handle. May be nil. Host *fd.FD + + // isHostBorrowed tells whether 'Host' is owned or borrowed. If owned, it's + // closed on destruction, otherwise it's released. + isHostBorrowed bool } // DecRef drops a reference on handles. func (h *handles) DecRef() { h.DecRefWithDestructor(func() { if h.Host != nil { - if err := h.Host.Close(); err != nil { - log.Warningf("error closing host file: %v", err) + if h.isHostBorrowed { + h.Host.Release() + } else { + if err := h.Host.Close(); err != nil { + log.Warningf("error closing host file: %v", err) + } } } // FIXME(b/38173783): Context is not plumbed here. @@ -56,7 +64,7 @@ func (h *handles) DecRef() { }) } -func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*handles, error) { +func newHandles(ctx context.Context, client *p9.Client, file contextFile, flags fs.FileFlags) (*handles, error) { _, newFile, err := file.walk(ctx, nil) if err != nil { return nil, err @@ -73,6 +81,9 @@ func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*han default: panic("impossible fs.FileFlags") } + if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(client.Version()) { + p9flags |= p9.OpenTruncate + } hostFile, _, _, err := newFile.open(ctx, p9flags) if err != nil { diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 95b064aea..91263ebdc 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -100,7 +100,7 @@ type inodeFileState struct { // true. // // Once readHandles becomes non-nil, it can't be changed until - // inodeFileState.Release(), because of a defect in the + // inodeFileState.Release()*, because of a defect in the // fsutil.CachedFileObject interface: there's no way for the caller of // fsutil.CachedFileObject.FD() to keep the returned FD open, so if we // racily replace readHandles after inodeFileState.FD() has returned @@ -108,6 +108,9 @@ type inodeFileState struct { // FD. writeHandles can be changed if writeHandlesRW is false, since // inodeFileState.FD() can't return a write-only FD, but can't be changed // if writeHandlesRW is true for the same reason. + // + // * There is one notable exception in recreateReadHandles(), where it dup's + // the FD and invalidates the page cache. readHandles *handles `state:"nosave"` writeHandles *handles `state:"nosave"` writeHandlesRW bool `state:"nosave"` @@ -175,48 +178,135 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles) // getHandles returns a set of handles for a new file using i opened with the // given flags. -func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags) (*handles, error) { +func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) { if !i.canShareHandles() { - return newHandles(ctx, i.file, flags) + return newHandles(ctx, i.s.client, i.file, flags) } + i.handlesMu.Lock() - defer i.handlesMu.Unlock() - // Do we already have usable shared handles? - if flags.Write { + h, invalidate, err := i.getHandlesLocked(ctx, flags) + i.handlesMu.Unlock() + + if invalidate { + cache.NotifyChangeFD() + if i.hostMappable != nil { + i.hostMappable.NotifyChangeFD() + } + } + + return h, err +} + +// getHandlesLocked returns a pointer to cached handles and a boolean indicating +// whether previously open read handle was recreated. Host mappings must be +// invalidated if so. +func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) { + // Check if we are able to use cached handles. + if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(i.s.client.Version()) { + // If we are truncating (and the gofer supports it), then we + // always need a new handle. Don't return one from the cache. + } else if flags.Write { if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) { + // File is opened for writing, and we have cached write + // handles that we can use. i.writeHandles.IncRef() - return i.writeHandles, nil + return i.writeHandles, false, nil } } else if i.readHandles != nil { + // File is opened for reading and we have cached handles. i.readHandles.IncRef() - return i.readHandles, nil + return i.readHandles, false, nil } - // No; get new handles and cache them for future sharing. - h, err := newHandles(ctx, i.file, flags) + + // Get new handles and cache them for future sharing. + h, err := newHandles(ctx, i.s.client, i.file, flags) if err != nil { - return nil, err + return nil, false, err + } + + // Read handles invalidation is needed if: + // - Mount option 'overlayfs_stale_read' is set + // - Read handle is open: nothing to invalidate otherwise + // - Write handle is not open: file was not open for write and is being open + // for write now (will trigger copy up in overlayfs). + invalidate := false + if i.s.overlayfsStaleRead && i.readHandles != nil && i.writeHandles == nil && flags.Write { + if err := i.recreateReadHandles(ctx, h, flags); err != nil { + return nil, false, err + } + invalidate = true } i.setSharedHandlesLocked(flags, h) - return h, nil + return h, invalidate, nil +} + +func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handles, flags fs.FileFlags) error { + h := writer + if !flags.Read { + // Writer can't be used for read, must create a new handle. + var err error + h, err = newHandles(ctx, i.s.client, i.file, fs.FileFlags{Read: true}) + if err != nil { + return err + } + defer h.DecRef() + } + + if i.readHandles.Host == nil { + // If current readHandles doesn't have a host FD, it can simply be replaced. + i.readHandles.DecRef() + + h.IncRef() + i.readHandles = h + return nil + } + + if h.Host == nil { + // Current read handle has a host FD and can't be replaced with one that + // doesn't, because it breaks fsutil.CachedFileObject.FD() contract. + log.Warningf("Read handle can't be invalidated, reads may return stale data") + return nil + } + + // Due to a defect in the fsutil.CachedFileObject interface, + // readHandles.Host.FD() may be used outside locks, making it impossible to + // reliably close it. To workaround it, we dup the new FD into the old one, so + // operations on the old will see the new data. Then, make the new handle take + // ownereship of the old FD and mark the old readHandle to not close the FD + // when done. + if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil { + return err + } + + h.Host.Close() + h.Host = fd.New(i.readHandles.Host.FD()) + i.readHandles.isHostBorrowed = true + i.readHandles.DecRef() + + h.IncRef() + i.readHandles = h + return nil } // ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt. func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { i.handlesMu.RLock() - defer i.handlesMu.RUnlock() - return i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts) + n, err := i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts) + i.handlesMu.RUnlock() + return n, err } // WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt. func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { i.handlesMu.RLock() - defer i.handlesMu.RUnlock() - return i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs) + n, err := i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs) + i.handlesMu.RUnlock() + return n, err } // SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. -func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error { - if i.skipSetAttr(mask) { +func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error { + if i.skipSetAttr(mask, forceSetTimestamps) { return nil } as, ans := attr.AccessTime.Unix() @@ -251,13 +341,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa // when: // - Mask is empty // - Mask contains only attributes that cannot be set in the gofer -// - Mask contains only atime and/or mtime, and host FD exists +// - forceSetTimestamps is false and mask contains only atime and/or mtime +// and host FD exists // // Updates to atime and mtime can be skipped because cached value will be // "close enough" to host value, given that operation went directly to host FD. // Skipping atime updates is particularly important to reduce the number of // operations sent to the Gofer for readonly files. -func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool { +func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool { // First remove attributes that cannot be updated. cpy := mask cpy.Type = false @@ -277,6 +368,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool { return false } + // If forceSetTimestamps was passed, then we cannot skip. + if forceSetTimestamps { + return false + } + + // Skip if we have a host FD. i.handlesMu.RLock() defer i.handlesMu.RUnlock() return (i.readHandles != nil && i.readHandles.Host != nil) || @@ -442,7 +539,7 @@ func (i *inodeOperations) NonBlockingOpen(ctx context.Context, p fs.PermMask) (* } func (i *inodeOperations) getFileDefault(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - h, err := i.fileState.getHandles(ctx, flags) + h, err := i.fileState.getHandles(ctx, flags, i.cachingInodeOps) if err != nil { return nil, err } diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 8c17603f8..c09f3b71c 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -234,6 +234,8 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, if err != nil { return nil, err } + // We're not going to use newFile after return. + defer newFile.close(ctx) // Stabilize the endpoint map while creation is in progress. unlock := i.session().endpoints.lock() @@ -254,7 +256,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, // Get the attributes of the file to create inode key. qid, mask, attr, err := getattr(ctx, newFile) if err != nil { - newFile.close(ctx) return nil, err } @@ -270,7 +271,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, // cloned and re-opened multiple times after creation. _, unopened, err := i.fileState.file.walk(ctx, []string{name}) if err != nil { - newFile.close(ctx) return nil, err } diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 50da865c1..4e358a46a 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -122,6 +122,10 @@ type session struct { // CachingInodeOperations created by the session. limitHostFDTranslation bool + // overlayfsStaleRead when set causes the readonly handle to be invalidated + // after file is open for write. + overlayfsStaleRead bool + // connID is a unique identifier for the session connection. connID string `state:"wait"` @@ -139,9 +143,9 @@ type session struct { // socket files. This allows unix domain sockets to be used with paths that // belong to a gofer. // - // TODO(b/77154739): there are few possible races with someone stat'ing the - // file and another deleting it concurrently, where the file will not be - // reported as socket file. + // TODO(gvisor.dev/issue/1200): there are few possible races with someone + // stat'ing the file and another deleting it concurrently, where the file + // will not be reported as socket file. endpoints *endpointMaps `state:"wait"` } @@ -257,6 +261,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF aname: o.aname, superBlockFlags: superBlockFlags, limitHostFDTranslation: o.limitHostFDTranslation, + overlayfsStaleRead: o.overlayfsStaleRead, mounter: mounter, } s.EnableLeakCheck("gofer.session") diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index b1080fb1a..23daeb528 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "host", @@ -20,6 +21,8 @@ go_library( "socket_unsafe.go", "tty.go", "util.go", + "util_amd64_unsafe.go", + "util_arm64_unsafe.go", "util_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 894ab01f0..a6e4a09e3 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo } // SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. -func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error { +func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error { if mask.Empty() { return nil } @@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err return unstableAttr(i.mops, &s), nil } -// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. +// Allocate implements fsutil.CachedFileObject.Allocate. func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error { return syscall.Fallocate(i.FD(), 0, offset, length) } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 2392787cb..107336a3e 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -385,3 +385,6 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { func (c *ConnectedEndpoint) Release() { c.ref.DecRefWithDestructor(c.close) } + +// CloseUnread implements transport.ConnectedEndpoint.CloseUnread. +func (c *ConnectedEndpoint) CloseUnread() {} diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 2526412a4..90331e3b2 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -43,12 +43,15 @@ type TTYFileOperations struct { // fgProcessGroup is the foreground process group that is currently // connected to this TTY. fgProcessGroup *kernel.ProcessGroup + + termios linux.KernelTermios } // newTTYFile returns a new fs.File that wraps a TTY FD. func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File { return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{ fileOperations: fileOperations{iops: iops}, + termios: linux.DefaultSlaveTermios, }) } @@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme t.mu.Lock() defer t.mu.Unlock() - // Are we allowed to do the write? - if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { - return 0, err + // Check whether TOSTOP is enabled. This corresponds to the check in + // drivers/tty/n_tty.c:n_tty_write(). + if t.termios.LEnabled(linux.TOSTOP) { + if err := t.checkChange(ctx, linux.SIGTTOU); err != nil { + return 0, err + } } return t.fileOperations.Write(ctx, file, src, offset) } @@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO return 0, err } err := ioctlSetTermios(fd, ioctl, &termios) + if err == nil { + t.termios.FromTermios(termios) + } return 0, err case linux.TIOCGPGRP: diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index bad61a9a1..e37e687c6 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -155,7 +155,7 @@ func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr { AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec), ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec), StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec), - Links: s.Nlink, + Links: uint64(s.Nlink), } } diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go new file mode 100644 index 000000000..66da6e9f5 --- /dev/null +++ b/pkg/sentry/fs/host/util_amd64_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package host + +import ( + "syscall" + "unsafe" +) + +func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { + var stat syscall.Stat_t + namePtr, err := syscall.BytePtrFromString(name) + if err != nil { + return stat, err + } + _, _, errno := syscall.Syscall6( + syscall.SYS_NEWFSTATAT, + uintptr(fd), + uintptr(unsafe.Pointer(namePtr)), + uintptr(unsafe.Pointer(&stat)), + uintptr(flags), + 0, 0) + if errno != 0 { + return stat, errno + } + return stat, nil +} diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go new file mode 100644 index 000000000..e8cb94aeb --- /dev/null +++ b/pkg/sentry/fs/host/util_arm64_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package host + +import ( + "syscall" + "unsafe" +) + +func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { + var stat syscall.Stat_t + namePtr, err := syscall.BytePtrFromString(name) + if err != nil { + return stat, err + } + _, _, errno := syscall.Syscall6( + syscall.SYS_FSTATAT, + uintptr(fd), + uintptr(unsafe.Pointer(namePtr)), + uintptr(unsafe.Pointer(&stat)), + uintptr(flags), + 0, 0) + if errno != 0 { + return stat, errno + } + return stat, nil +} diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go index 2b76f1065..3ab36b088 100644 --- a/pkg/sentry/fs/host/util_unsafe.go +++ b/pkg/sentry/fs/host/util_unsafe.go @@ -116,22 +116,3 @@ func setTimestamps(fd int, ts fs.TimeSpec) error { } return nil } - -func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { - var stat syscall.Stat_t - namePtr, err := syscall.BytePtrFromString(name) - if err != nil { - return stat, err - } - _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, - uintptr(fd), - uintptr(unsafe.Pointer(namePtr)), - uintptr(unsafe.Pointer(&stat)), - uintptr(flags), - 0, 0) - if errno != 0 { - return stat, errno - } - return stat, nil -} diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index f4ddfa406..91e2fde2f 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -270,6 +270,14 @@ func (i *Inode) Getxattr(name string) (string, error) { return i.InodeOperations.Getxattr(i, name) } +// Setxattr calls i.InodeOperations.Setxattr with i as the Inode. +func (i *Inode) Setxattr(name, value string) error { + if i.overlay != nil { + return overlaySetxattr(i.overlay, name, value) + } + return i.InodeOperations.Setxattr(i, name, value) +} + // Listxattr calls i.InodeOperations.Listxattr with i as the Inode. func (i *Inode) Listxattr() (map[string]struct{}, error) { if i.overlay != nil { @@ -344,6 +352,10 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error // Truncate calls i.InodeOperations.Truncate with i as the Inode. func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error { + if IsDir(i.StableAttr) { + return syserror.EISDIR + } + if i.overlay != nil { return overlayTruncate(ctx, i.overlay, d, size) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 246b97161..13d11e001 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "fmt" "strings" "gvisor.dev/gvisor/pkg/abi/linux" @@ -207,6 +208,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) { + // Sanity check. + if parent.Inode.overlay == nil { + panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations)) + } + // Dirent.Create takes renameMu if the Inode is an overlay Inode. if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err @@ -430,7 +436,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena } func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { - if err := copyUp(ctx, parent); err != nil { + if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err } @@ -456,7 +462,9 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri inode.DecRef() return nil, err } - return NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil + // Use the parent's MountSource, since that corresponds to the overlay, + // and not the upper filesystem. + return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil } func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint { @@ -544,6 +552,11 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) { return s, err } +// TODO(b/146028302): Support setxattr for overlayfs. +func overlaySetxattr(o *overlayEntry, name, value string) error { + return syserror.EOPNOTSUPP +} + func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) { o.copyMu.RLock() defer o.copyMu.RUnlock() diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index c7f4e2d13..ba3e0233d 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "sync/atomic" @@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } // WriteTo implements FileOperations.WriteTo. -func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } @@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { } // ReadFrom implements FileOperations.ReadFrom. -func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 08d7c0c57..8d62642e7 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "lock_range", diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 1d3ff39e0..25573e986 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syncutil" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/third_party/gvsync" ) // The virtual filesystem implements an overlay configuration. For a high-level @@ -199,7 +199,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu gvsync.DowngradableRWMutex `state:"nosave"` + dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index c7599d1f6..75cbb0622 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "proc", @@ -51,6 +52,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/tcpip/header", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 5e28982c5..402919924 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "io" + "reflect" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -33,6 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // newNet creates a new proc net entry. @@ -40,7 +43,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo var contents map[string]*fs.Inode if s := p.k.NetworkStack(); s != nil { contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a @@ -57,7 +61,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo // (ClockGetres returns 1ns resolution). "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")), - "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), @@ -66,7 +70,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo if s.SupportsIPv6() { contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) - contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) } } @@ -195,6 +199,193 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netSnmp implements seqfile.SeqSource for /proc/net/snmp. +// +// +stateify savable +type netSnmp struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netSnmp) NeedsUpdate(generation int64) bool { + return true +} + +type snmpLine struct { + prefix string + header string +} + +var snmp = []snmpLine{ + { + prefix: "Ip", + header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates", + }, + { + prefix: "Icmp", + header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps", + }, + { + prefix: "IcmpMsg", + }, + { + prefix: "Tcp", + header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors", + }, + { + prefix: "Udp", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, + { + prefix: "UdpLite", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, +} + +func toSlice(a interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(a)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + +func sprintSlice(s []uint64) string { + if len(s) == 0 { + return "" + } + r := fmt.Sprint(s) + return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's +// net/core/net-procfs.c:dev_seq_show. +func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + contents := make([]string, 0, len(snmp)*2) + types := []interface{}{ + &inet.StatSNMPIP{}, + &inet.StatSNMPICMP{}, + nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats. + &inet.StatSNMPTCP{}, + &inet.StatSNMPUDP{}, + &inet.StatSNMPUDPLite{}, + } + for i, stat := range types { + line := snmp[i] + if stat == nil { + contents = append( + contents, + fmt.Sprintf("%s:\n", line.prefix), + fmt.Sprintf("%s:\n", line.prefix), + ) + continue + } + if err := n.s.Statistics(stat, line.prefix); err != nil { + if err == syserror.EOPNOTSUPP { + log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } else { + log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } + } + var values string + if line.prefix == "Tcp" { + tcp := stat.(*inet.StatSNMPTCP) + // "Tcp" needs special processing because MaxConn is signed. RFC 2012. + values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + } else { + values = sprintSlice(toSlice(stat)) + } + contents = append( + contents, + fmt.Sprintf("%s: %s\n", line.prefix, line.header), + fmt.Sprintf("%s: %s\n", line.prefix, values), + ) + } + + data := make([]seqfile.SeqData, 0, len(snmp)*2) + for _, l := range contents { + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)}) + } + + return data, 0 +} + +// netRoute implements seqfile.SeqSource for /proc/net/route. +// +// +stateify savable +type netRoute struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netRoute) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. +func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + interfaces := n.s.Interfaces() + contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"} + for _, rt := range n.s.RouteTable() { + // /proc/net/route only includes ipv4 routes. + if rt.Family != linux.AF_INET { + continue + } + + // /proc/net/route does not include broadcast or multicast routes. + if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST { + continue + } + + iface, ok := interfaces[rt.OutputInterface] + if !ok || iface.Name == "lo" { + continue + } + + var ( + gw uint32 + prefix uint32 + flags = linux.RTF_UP + ) + if len(rt.GatewayAddr) == header.IPv4AddressSize { + flags |= linux.RTF_GATEWAY + gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + } + if len(rt.DstAddr) == header.IPv4AddressSize { + prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + } + l := fmt.Sprintf( + "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", + iface.Name, + prefix, + gw, + flags, + 0, // RefCnt. + 0, // Use. + 0, // Metric. + (uint32(1)<<rt.DstLen)-1, + 0, // MTU. + 0, // Window. + 0, // RTT. + ) + contents = append(contents, l) + } + + var data []seqfile.SeqData + for _, l := range contents { + l = fmt.Sprintf("%-127s\n", l) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)}) + } + + return data, 0 +} + // netUnix implements seqfile.SeqSource for /proc/net/unix. // // +stateify savable @@ -310,44 +501,51 @@ func networkToHost16(n uint16) uint16 { return usermem.ByteOrder.Uint16(buf[:]) } -func writeInetAddr(w io.Writer, a linux.SockAddrInet) { - // linux.SockAddrInet.Port is stored in the network byte order and is - // printed like a number in host byte order. Note that all numbers in host - // byte order are printed with the most-significant byte first when - // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. - port := networkToHost16(a.Port) - - // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order - // (i.e. most-significant byte in index 0). Linux represents this as a - // __be32 which is a typedef for an unsigned int, and is printed with - // %X. This means that for a little-endian machine, Linux prints the - // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, - // which makes it have the equivalent encoding to a __be32 on a little - // endian machine. Note that this operation is a no-op on a big endian - // machine. Then similar to Linux, we format it with %X, which will print - // the most-significant byte of the __be32 address first, which is now - // actually the least-significant byte of the original address in - // linux.SockAddrInet.Addr on little endian machines, due to the conversion. - addr := usermem.ByteOrder.Uint32(a.Addr[:]) - - fmt.Fprintf(w, "%08X:%04X ", addr, port) -} +func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { + switch family { + case linux.AF_INET: + var a linux.SockAddrInet + if i != nil { + a = *i.(*linux.SockAddrInet) + } -// netTCP implements seqfile.SeqSource for /proc/net/tcp. -// -// +stateify savable -type netTCP struct { - k *kernel.Kernel -} + // linux.SockAddrInet.Port is stored in the network byte order and is + // printed like a number in host byte order. Note that all numbers in host + // byte order are printed with the most-significant byte first when + // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. + port := networkToHost16(a.Port) + + // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order + // (i.e. most-significant byte in index 0). Linux represents this as a + // __be32 which is a typedef for an unsigned int, and is printed with + // %X. This means that for a little-endian machine, Linux prints the + // least-significant byte of the address first. To emulate this, we first + // invert the byte order for the address using usermem.ByteOrder.Uint32, + // which makes it have the equivalent encoding to a __be32 on a little + // endian machine. Note that this operation is a no-op on a big endian + // machine. Then similar to Linux, we format it with %X, which will print + // the most-significant byte of the __be32 address first, which is now + // actually the least-significant byte of the original address in + // linux.SockAddrInet.Addr on little endian machines, due to the conversion. + addr := usermem.ByteOrder.Uint32(a.Addr[:]) + + fmt.Fprintf(w, "%08X:%04X ", addr, port) + case linux.AF_INET6: + var a linux.SockAddrInet6 + if i != nil { + a = *i.(*linux.SockAddrInet6) + } -// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. -func (*netTCP) NeedsUpdate(generation int64) bool { - return true + port := networkToHost16(a.Port) + addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) + } } -// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. -func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { +func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) { // t may be nil here if our caller is not part of a task goroutine. This can // happen for example if we're here for "sentryctl cat". When t is nil, // degrade gracefully and retrieve what we can. @@ -358,7 +556,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se } var buf bytes.Buffer - for _, se := range n.k.ListSockets() { + for _, se := range k.ListSockets() { s := se.Sock.Get() if s == nil { log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) @@ -369,7 +567,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se if !ok { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } - if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) { + if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { s.DecRef() // Not tcp4 sockets. continue @@ -384,22 +582,22 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se fmt.Fprintf(&buf, "%4d: ", se.ID) // Field: local_adddress. - var localAddr linux.SockAddrInet + var localAddr linux.SockAddr if t != nil { if local, _, err := sops.GetSockName(t); err == nil { - localAddr = *local.(*linux.SockAddrInet) + localAddr = local } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, fa, localAddr) // Field: rem_address. - var remoteAddr linux.SockAddrInet + var remoteAddr linux.SockAddr if t != nil { if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = *remote.(*linux.SockAddrInet) + remoteAddr = remote } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, fa, remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) @@ -465,7 +663,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se data := []seqfile.SeqData{ { - Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n"), + Buf: header, Handle: n, }, { @@ -476,6 +674,42 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netTCP implements seqfile.SeqSource for /proc/net/tcp. +// +// +stateify savable +type netTCP struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header) +} + +// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6. +// +// +stateify savable +type netTCP6 struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP6) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header) +} + // netUDP implements seqfile.SeqSource for /proc/net/udp. // // +stateify savable @@ -529,7 +763,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se localAddr = *local.(*linux.SockAddrInet) } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, linux.AF_INET, &localAddr) // Field: rem_address. var remoteAddr linux.SockAddrInet @@ -538,7 +772,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se remoteAddr = *remote.(*linux.SockAddrInet) } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, linux.AF_INET, &remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 0ef13f2f5..56e92721e 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // But for whatever crazy reason, you can still walk to the given node. for _, tg := range rpf.iops.pidns.ThreadGroups() { if leader := tg.Leader(); leader != nil { - name := strconv.FormatUint(uint64(tg.ID()), 10) + name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10) m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice) names = append(names, name) } diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 20c3eefc8..fe7067be1 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "seqfile", diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index f3b63dfc2..bd93f83fa 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -64,7 +64,7 @@ var _ fs.InodeOperations = (*tcpMemInode)(nil) func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode { tm := &tcpMemInode{ - SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), s: s, dir: dir, } @@ -77,6 +77,11 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir return fs.NewInode(ctx, tm, msrc, sattr) } +// Truncate implements fs.InodeOperations.Truncate. +func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true @@ -168,14 +173,15 @@ func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error { // +stateify savable type tcpSack struct { + fsutil.SimpleFileInode + stack inet.Stack `state:"wait"` enabled *bool - fsutil.SimpleFileInode } func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { ts := &tcpSack{ - SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), stack: s, } sattr := fs.StableAttr{ @@ -187,6 +193,11 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f return fs.NewInode(ctx, ts, msrc, sattr) } +// Truncate implements fs.InodeOperations.Truncate. +func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 87184ec67..9bf4b4527 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -67,29 +67,28 @@ type taskDir struct { var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. -func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode { +func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - // FIXME(b/123511468): create the correct io file for threads. - "io": newIO(t, msrc), + "auxv": newAuxvec(t, msrc), + "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + "comm": newComm(t, msrc), + "environ": newExecArgInode(t, msrc, environExecArg), + "exe": newExe(t, msrc), + "fd": newFdDir(t, msrc), + "fdinfo": newFdInfoDir(t, msrc), + "gid_map": newGIDMap(t, msrc), + "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "ns": newNamespaceDir(t, msrc), "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, showSubtasks, p.pidns), + "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), "statm": newStatm(t, msrc), "status": newStatus(t, msrc, p.pidns), "uid_map": newUIDMap(t, msrc), } - if showSubtasks { + if isThreadGroup { contents["task"] = p.newSubtasks(t, msrc) } if len(p.cgroupControllers) > 0 { @@ -605,6 +604,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(&buf, "Mems_allowed:\t1\n") + fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n") return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0 } @@ -619,8 +622,11 @@ type ioData struct { ioUsage } -func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) +func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { + if isThreadGroup { + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) + } + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -639,7 +645,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se io.Accumulate(i.IOUsage()) var buf bytes.Buffer - fmt.Fprintf(&buf, "char: %d\n", io.CharsRead) + fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead) fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten) fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls) fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 516efcc4c..012cb3e44 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "ramfs", diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index eed1c2854..311798811 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,6 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/secio" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserror" ) @@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } // Check whether or not the objects being sliced are stream-oriented - // (i.e. pipes or sockets). If yes, we elide checks and offset locks. - srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr) - dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr) + // (i.e. pipes or sockets). For all stream-oriented files and files + // where a specific offiset is not request, we acquire the file mutex. + // This has two important side effects. First, it provides the standard + // protection against concurrent writes that would mutate the offset. + // Second, it prevents Splice deadlocks. Only internal anonymous files + // implement the ReadFrom and WriteTo methods directly, and since such + // anonymous files are referred to by a unique fs.File object, we know + // that the file mutex takes strict precedence over internal locks. + // Since we enforce lock ordering here, we can't deadlock by using + // using a file in two different splice operations simultaneously. + srcPipe := !IsRegular(src.Dirent.Inode.StableAttr) + dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr) + dstAppend := !dstPipe && dst.Flags().Append + srcLock := srcPipe || !opts.SrcOffset + dstLock := dstPipe || !opts.DstOffset || dstAppend - if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset { + switch { + case srcLock && dstLock: switch { case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() if !src.mu.Lock(ctx) { + dst.mu.Unlock() return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() if !dst.mu.Lock(ctx) { + src.mu.Unlock() return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() + srcLock = false // Only need one unlock. } // Use both offsets (locked). opts.DstStart = dst.offset opts.SrcStart = src.offset - } else if !dstPipe && !opts.DstOffset { + case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() opts.DstStart = dst.offset // Safe: locked. - } else if !srcPipe && !opts.SrcOffset { + case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() opts.SrcStart = src.offset // Safe: locked. } - // Check append-only mode and the limit. - if !dstPipe { + var err error + if dstAppend { unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append) defer unlock() - if dst.Flags().Append { - if opts.DstOffset { - // We need to acquire the lock. - if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted - } - defer dst.mu.Unlock() - } - // Figure out the appropriate offset to use. - if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil { - return 0, err - } - } + // Figure out the appropriate offset to use. + err = dst.offsetForAppend(ctx, &opts.DstStart) + } + if err == nil && !dstPipe { // Enforce file limits. limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + err = syserror.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } } + if err != nil { + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return 0, err + } - // Attempt to do a WriteTo; this is likely the most efficient. - // - // The underlying implementation may be able to donate buffers. - newOpts := SpliceOpts{ - Length: opts.Length, - SrcStart: opts.SrcStart, - SrcOffset: !srcPipe, - Dup: opts.Dup, - DstStart: opts.DstStart, - DstOffset: !dstPipe, + // Construct readers and writers for the splice. This is used to + // provide a safer locking path for the WriteTo/ReadFrom operations + // (since they will otherwise go through public interface methods which + // conflict with locking done above), and simplifies the fallback path. + w := &lockedWriter{ + Ctx: ctx, + File: dst, + Offset: opts.DstStart, } - n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts) - if n == 0 && err != nil { - // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also - // be more efficient than a copy if buffers are cached or readily - // available. (It's unlikely that they can actually be donate - n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts) + r := &lockedReader{ + Ctx: ctx, + File: src, + Offset: opts.SrcStart, } - if n == 0 && err != nil { - // If we've failed up to here, and at least one of the sources - // is a pipe or socket, then we can't properly support dup. - // Return an error indicating that this operation is not - // supported. - if (srcPipe || dstPipe) && newOpts.Dup { - return 0, syserror.EINVAL - } - // We failed to splice the files. But that's fine; we just fall - // back to a slow path in this case. This copies without doing - // any mode changes, so should still be more efficient. - var ( - r io.Reader - w io.Writer - ) - fw := &lockedWriter{ - Ctx: ctx, - File: dst, - } - if newOpts.DstOffset { - // Use the provided offset. - w = secio.NewOffsetWriter(fw, newOpts.DstStart) - } else { - // Writes will proceed with no offset. - w = fw - } - fr := &lockedReader{ - Ctx: ctx, - File: src, - } - if newOpts.SrcOffset { - // Limit to the given offset and length. - r = io.NewSectionReader(fr, opts.SrcStart, opts.Length) - } else { - // Limit just to the given length. - r = &io.LimitedReader{fr, opts.Length} - } + // Attempt to do a WriteTo; this is likely the most efficient. + n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) + if n == 0 && err == syserror.ENOSYS && !opts.Dup { + // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be + // more efficient than a copy if buffers are cached or readily + // available. (It's unlikely that they can actually be donated). + n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length) + } - // Copy between the two. - n, err = io.Copy(w, r) + // Support one last fallback option, but only if at least one of + // the source and destination are regular files. This is because + // if we block at some point, we could lose data. If the source is + // not a pipe then reading is not destructive; if the destination + // is a regular file, then it is guaranteed not to block writing. + if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) { + // Fallback to an in-kernel copy. + n, err = io.Copy(w, &io.LimitedReader{ + R: r, + N: opts.Length, + }) } // Update offsets, if required. @@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } } + // Drop locks. + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return n, err } diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index 70fa3af89..25f0f124e 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "sys", srcs = [ diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index 1d80daeaf..a215c1b95 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "timerfd", srcs = ["timerfd.go"], diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 59403d9db..f8bf663bb 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i } // Notify implements ktime.TimerListener.Notify. -func (t *TimerOperations) Notify(exp uint64) { +func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { atomic.AddUint64(&t.val, exp) t.events.Notify(waiter.EventIn) + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 8f7eb5757..59ce400c2 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "tmpfs", diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 159fb7c08..69089c8a8 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -324,7 +324,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(ctx, true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) // Build pipe InodeOperations. iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5e9327aec..95ad98cb0 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "tty", @@ -23,6 +24,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 1d128532b..2f639c823 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { // Release implements fs.InodeOperations.Release. func (d *dirInodeOperations) Release(ctx context.Context) { + d.mu.Lock() + defer d.mu.Unlock() + d.master.DecRef() if len(d.slaves) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 92ec1ca18..934828c12 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -76,6 +76,11 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn func (mi *masterInodeOperations) Release(ctx context.Context) { } +// Truncate implements fs.InodeOperations.Truncate. +func (masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. // // It allocates a new terminal. @@ -172,6 +177,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme return 0, mf.t.ld.windowSize(ctx, io, args) case linux.TIOCSWINSZ: return 0, mf.t.ld.setWindowSize(ctx, io, args) + case linux.TIOCSCTTY: + // Make the given terminal the controlling terminal of the + // calling process. + return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY @@ -185,8 +203,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TCSETS, linux.TCSETSW, linux.TCSETSF, - linux.TIOCGPGRP, - linux.TIOCSPGRP, linux.TIOCGWINSZ, linux.TIOCSWINSZ, linux.TIOCSETD, @@ -200,8 +216,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TIOCEXCL, linux.TIOCNXCL, linux.TIOCGEXCL, - linux.TIOCNOTTY, - linux.TIOCSCTTY, linux.TIOCGSID, linux.TIOCGETD, linux.TIOCVHANGUP, diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index e30266404..2a51e6bab 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -72,6 +72,11 @@ func (si *slaveInodeOperations) Release(ctx context.Context) { si.t.DecRef() } +// Truncate implements fs.InodeOperations.Truncate. +func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. // // This may race with destruction of the terminal. If the terminal is gone, it @@ -152,9 +157,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - // TODO(b/129283598): Implement once we have support for job - // control. - return 0, nil + return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index b7cecb2ed..917f90cc0 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,7 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) // Terminal is a pseudoterminal. @@ -26,23 +29,100 @@ import ( type Terminal struct { refs.AtomicRefCount - // n is the terminal index. + // n is the terminal index. It is immutable. n uint32 - // d is the containing directory. + // d is the containing directory. It is immutable. d *dirInodeOperations - // ld is the line discipline of the terminal. + // ld is the line discipline of the terminal. It is immutable. ld *lineDiscipline + + // masterKTTY contains the controlling process of the master end of + // this terminal. This field is immutable. + masterKTTY *kernel.TTY + + // slaveKTTY contains the controlling process of the slave end of this + // terminal. This field is immutable. + slaveKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { termios := linux.DefaultSlaveTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + slaveKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t } + +// setControllingTTY makes tm the controlling terminal of the calling thread +// group. +func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("setControllingTTY must be called from a task context") + } + + return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int()) +} + +// releaseControllingTTY removes tm as the controlling terminal of the calling +// thread group. +func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("releaseControllingTTY must be called from a task context") + } + + return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster)) +} + +// foregroundProcessGroup gets the process group ID of tm's foreground process. +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("foregroundProcessGroup must be called from a task context") + } + + ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster)) + if err != nil { + return 0, err + } + + // Write it out to *arg. + _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err +} + +// foregroundProcessGroup sets tm's foreground process. +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("setForegroundProcessGroup must be called from a task context") + } + + // Read in the process group ID. + var pgid int32 + if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + + ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid)) + return uintptr(ret), err +} + +func (tm *Terminal) tty(isMaster bool) *kernel.TTY { + if isMaster { + return tm.masterKTTY + } + return tm.slaveKTTY +} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 9e8ebb907..880b7bcd3 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,8 +1,9 @@ -package(licenses = ["notice"]) - -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") +package(licenses = ["notice"]) + go_template_instance( name = "dirent_list", out = "dirent_list.go", @@ -37,6 +38,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/fd", + "//pkg/fspath", "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/context", diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 9fddb4c4c..bfc46dfa6 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 10a8083a0..177ce2cb9 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -50,7 +50,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -81,7 +81,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf ctx := contexttest.Context(b) creds := auth.CredentialsFromContext(ctx) - if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil { + if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } return func() { diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index cea89bcd9..a2d8c3ad6 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -154,7 +154,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds toRead = len(dst) } - n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) + n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) if n < toRead { return n, syserror.EIO } @@ -174,7 +174,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { var childPhyBlk uint32 - err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) + err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 213aa3919..181727ef7 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -87,12 +87,14 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) regFile := regularFile{ inode: inode{ + fs: &filesystem{ + dev: bytes.NewReader(mockDisk), + }, diskInode: &disklayout.InodeNew{ InodeOld: disklayout.InodeOld{ SizeLo: getMockBMFileFize(), }, }, - dev: bytes.NewReader(mockDisk), blkSize: uint64(mockBMBlkSize), }, } diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 054fb42b6..a080cb189 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -41,16 +41,18 @@ func newDentry(in *inode) *dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef() { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef() bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { - d.inode.decRef(vfsfs.Impl().(*filesystem)) +func (d *dentry) DecRef() { + // FIXME(b/134676337): filesystem.mu may not be locked as required by + // inode.decRef(). + d.inode.decRef() } diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index b51f3e18d..91802dc1e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba } if !cb.Handle(vfs.Dirent{ - Name: child.diskDirent.FileName(), - Type: fs.ToDirentType(childType), - Ino: uint64(child.diskDirent.Inode()), - Off: fd.off, + Name: child.diskDirent.FileName(), + Type: fs.ToDirentType(childType), + Ino: uint64(child.diskDirent.Inode()), + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil @@ -301,8 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in return offset, nil } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // mmap(2) specifies that EACCESS should be returned for non-regular file fds. return syserror.EACCES } diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 907d35b7e..fcfaf5c3e 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "disklayout", diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index f10accafc..4b7d17dc6 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -40,14 +40,14 @@ var _ vfs.FilesystemType = (*FilesystemType)(nil) // Currently there are two ways of mounting an ext(2/3/4) fs: // 1. Specify a mount with our internal special MountType in the OCI spec. // 2. Expose the device to the container and mount it from application layer. -func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) { +func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) { if opts.InternalData == nil { // User mount call. // TODO(b/134676337): Open the device specified by `source` and return that. panic("unimplemented") } - // NewFilesystem call originated from within the sentry. + // GetFilesystem call originated from within the sentry. devFd, ok := opts.InternalData.(int) if !ok { return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device") @@ -91,8 +91,8 @@ func isCompatible(sb disklayout.SuperBlock) bool { return true } -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { // TODO(b/134676337): Ensure that the user is mounting readonly. If not, // EACCESS should be returned according to mount(2). Filesystem independent // flags (like readonly) are currently not available in pkg/sentry/vfs. @@ -103,7 +103,7 @@ func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials } fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) fs.sb, err = readSuperBlock(dev) if err != nil { return nil, nil, err diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 63cf7aeaf..e9f756732 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -66,7 +66,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -147,55 +147,54 @@ func TestSeek(t *testing.T) { t.Fatalf("vfsfs.OpenAt failed: %v", err) } - if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { + if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { t.Errorf("expected seek position 0, got %d and error %v", n, err) } - stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{}) + stat, err := fd.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err) } // We should be able to seek beyond the end of file. size := int64(stat.Size) - if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { + if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } - if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { + if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err) } // Make sure negative offsets work with SEEK_CUR. - if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } // Make sure SEEK_END works with regular files. - switch fd.Impl().(type) { - case *regularFileFD: + if _, ok := fd.Impl().(*regularFileFD); ok { // Seek back to 0. - if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { + if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", 0, n, err) } // Seek forward beyond EOF. - if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } } @@ -456,7 +455,7 @@ func TestRead(t *testing.T) { want := make([]byte, 1) for { n, err := f.Read(want) - fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) + fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) if diff := cmp.Diff(got, want); diff != "" { t.Errorf("file data mismatch (-want +got):\n%s", diff) @@ -464,7 +463,7 @@ func TestRead(t *testing.T) { // Make sure there is no more file data left after getting EOF. if n == 0 || err == io.EOF { - if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { + if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image) } @@ -509,27 +508,27 @@ func TestIterDirents(t *testing.T) { } wantDirents := []vfs.Dirent{ - vfs.Dirent{ + { Name: ".", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "..", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "lost+found", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "file.txt", Type: linux.DT_REG, }, - vfs.Dirent{ + { Name: "bigfile.txt", Type: linux.DT_REG, }, - vfs.Dirent{ + { Name: "symlink.txt", Type: linux.DT_LNK, }, @@ -574,7 +573,7 @@ func TestIterDirents(t *testing.T) { } cb := &iterDirentsCb{} - if err = fd.Impl().IterDirents(ctx, cb); err != nil { + if err = fd.IterDirents(ctx, cb); err != nil { t.Fatalf("dir fd.IterDirents() failed: %v", err) } @@ -584,7 +583,7 @@ func TestIterDirents(t *testing.T) { // Ignore the inode number and offset of dirents because those are likely to // change as the underlying image changes. cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "Ino" || p.String() == "Off" + return p.String() == "Ino" || p.String() == "NextOff" }, cmp.Ignore()) if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" { t.Errorf("dirents mismatch (-want +got):\n%s", diff) diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 38b68a2d3..3d3ebaca6 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -99,7 +99,7 @@ func (f *extentFile) buildExtTree() error { func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) { var header disklayout.ExtentHeader off := entry.PhysicalBlock() * f.regFile.inode.blkSize - err := readFromDisk(f.regFile.inode.dev, int64(off), &header) + err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header) if err != nil { return nil, err } @@ -115,7 +115,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla curEntry = &disklayout.ExtentIdx{} } - err := readFromDisk(f.regFile.inode.dev, int64(off), curEntry) + err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry) if err != nil { return nil, err } @@ -229,7 +229,7 @@ func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byt toRead = len(dst) } - n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], int64(readStart)) + n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart)) if n < toRead { return n, syserror.EIO } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index 42d0a484b..a2382daa3 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -180,13 +180,15 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] mockExtentFile := &extentFile{ regFile: regularFile{ inode: inode{ + fs: &filesystem{ + dev: bytes.NewReader(mockDisk), + }, diskInode: &disklayout.InodeNew{ InodeOld: disklayout.InodeOld{ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root), }, }, blkSize: mockExtentBlkSize, - dev: bytes.NewReader(mockDisk), }, }, } diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index a0065343b..5eca2b83f 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -36,16 +36,13 @@ type fileDescription struct { } func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) + return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } func (fd *fileDescription) inode() *inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode + return fd.vfsfd.Dentry().Impl().(*dentry).inode } -// OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *fileDescription) OnClose() error { return nil } - // StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { return fd.flags, nil diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 2d15e8aaf..d7e87979a 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -20,6 +20,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -441,3 +442,46 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return nil, err + } + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return "", err + } + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index e6c847a71..24249525c 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -16,7 +16,6 @@ package ext import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -42,13 +41,13 @@ type inode struct { // refs is a reference count. refs is accessed using atomic memory operations. refs int64 + // fs is the containing filesystem. + fs *filesystem + // inodeNum is the inode number of this inode on disk. This is used to // identify inodes within the ext filesystem. inodeNum uint32 - // dev represents the underlying device. Same as filesystem.dev. - dev io.ReaderAt - // blkSize is the fs data block size. Same as filesystem.sb.BlockSize(). blkSize uint64 @@ -81,10 +80,10 @@ func (in *inode) tryIncRef() bool { // decRef decrements the inode ref count and releases the inode resources if // the ref count hits 0. // -// Precondition: Must have locked fs.mu. -func (in *inode) decRef(fs *filesystem) { +// Precondition: Must have locked filesystem.mu. +func (in *inode) decRef() { if refs := atomic.AddInt64(&in.refs, -1); refs == 0 { - delete(fs.inodeCache, in.inodeNum) + delete(in.fs.inodeCache, in.inodeNum) } else if refs < 0 { panic("ext.inode.decRef() called without holding a reference") } @@ -117,8 +116,8 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) { // Build the inode based on its type. inode := inode{ + fs: fs, inodeNum: inodeNum, - dev: fs.dev, blkSize: blkSize, diskInode: diskInode, } @@ -154,11 +153,14 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v if err := in.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } + mnt := rp.Mount() switch in.impl.(type) { case *regularFile: var fd regularFileFD fd.flags = flags - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, mnt, vfsd) return &fd.vfsfd, nil case *directory: // Can't open directories writably. This check is not necessary for a read @@ -167,8 +169,10 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.EISDIR } var fd directoryFD - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) fd.flags = flags + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, mnt, vfsd) return &fd.vfsfd, nil case *symlink: if flags&linux.O_PATH == 0 { @@ -177,7 +181,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v } var fd symlinkFD fd.flags = flags - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, mnt, vfsd) return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index ffc76ba5b..aec33e00a 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -152,8 +152,8 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( return offset, nil } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index e06548a98..bdf8705c1 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -105,7 +105,7 @@ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int6 return 0, syserror.EBADF } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.EBADF } diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD new file mode 100644 index 000000000..52596c090 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -0,0 +1,60 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "slot_list", + out = "slot_list.go", + package = "kernfs", + prefix = "slot", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*slot", + "Linker": "*slot", + }, +) + +go_library( + name = "kernfs", + srcs = [ + "dynamic_bytes_file.go", + "fd_impl_util.go", + "filesystem.go", + "inode_impl_util.go", + "kernfs.go", + "slot_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/fspath", + "//pkg/log", + "//pkg/refs", + "//pkg/sentry/context", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) + +go_test( + name = "kernfs_test", + size = "small", + srcs = ["kernfs_test.go"], + deps = [ + ":kernfs", + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go new file mode 100644 index 000000000..30c06baf0 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -0,0 +1,131 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// DynamicBytesFile implements kernfs.Inode and represents a read-only +// file whose contents are backed by a vfs.DynamicBytesSource. +// +// Must be initialized with Init before first use. +type DynamicBytesFile struct { + InodeAttrs + InodeNoopRefCount + InodeNotDirectory + InodeNotSymlink + + data vfs.DynamicBytesSource +} + +// Init intializes a dynamic bytes file. +func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) { + f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444) + f.data = data +} + +// Open implements Inode.Open. +func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &DynamicBytesFD{} + fd.Init(rp.Mount(), vfsd, f.data, flags) + return &fd.vfsfd, nil +} + +// SetStat implements Inode.SetStat. +func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error { + // DynamicBytesFiles are immutable. + return syserror.EPERM +} + +// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a +// DynamicBytesFile. +// +// Must be initialized with Init before first use. +type DynamicBytesFD struct { + vfs.FileDescriptionDefaultImpl + vfs.DynamicBytesFileDescriptionImpl + + vfsfd vfs.FileDescription + inode Inode + flags uint32 +} + +// Init initializes a DynamicBytesFD. +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) { + m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + fd.flags = flags + fd.inode = d.Impl().(*Dentry).inode + fd.SetDataSource(data) + fd.vfsfd.Init(fd, m, d) +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence) +} + +// Read implmenets vfs.FileDescriptionImpl.Read. +func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts) +} + +// PRead implmenets vfs.FileDescriptionImpl.PRead. +func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return fd.FileDescriptionDefaultImpl.Write(ctx, src, opts) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.FileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *DynamicBytesFD) Release() {} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(fs), nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { + // DynamicBytesFiles are immutable. + return syserror.EPERM +} + +// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. +func (fd *DynamicBytesFD) StatusFlags(ctx context.Context) (uint32, error) { + return fd.flags, nil +} + +// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. +func (fd *DynamicBytesFD) SetStatusFlags(ctx context.Context, flags uint32) error { + // None of the flags settable by fcntl(F_SETFL) are supported, so this is a + // no-op. + return nil +} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go new file mode 100644 index 000000000..d6c18937a --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -0,0 +1,207 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory +// inode that uses OrderChildren to track child nodes. GenericDirectoryFD is not +// compatible with dynamic directories. +// +// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling +// IterDirents callback. The IterDirents callback therefore cannot hash or +// unhash children, or recursively call IterDirents on the same underlying +// inode. +// +// Must be initialize with Init before first use. +type GenericDirectoryFD struct { + vfs.FileDescriptionDefaultImpl + vfs.DirectoryFileDescriptionDefaultImpl + + vfsfd vfs.FileDescription + children *OrderedChildren + flags uint32 + off int64 +} + +// Init initializes a GenericDirectoryFD. +func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) { + m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + fd.children = children + fd.flags = flags + fd.vfsfd.Init(fd, m, d) +} + +// VFSFileDescription returns a pointer to the vfs.FileDescription representing +// this object. +func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription { + return &fd.vfsfd +} + +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts) +} + +// Read implmenets vfs.FileDescriptionImpl.Read. +func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts) +} + +// PRead implmenets vfs.FileDescriptionImpl.PRead. +func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) +} + +// Release implements vfs.FileDecriptionImpl.Release. +func (fd *GenericDirectoryFD) Release() {} + +func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem() +} + +func (fd *GenericDirectoryFD) inode() Inode { + return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode +} + +// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds +// o.mu when calling cb. +func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + vfsFS := fd.filesystem() + fs := vfsFS.Impl().(*Filesystem) + vfsd := fd.vfsfd.VirtualDentry().Dentry() + + fs.mu.Lock() + defer fs.mu.Unlock() + + // Handle ".". + if fd.off == 0 { + stat := fd.inode().Stat(vfsFS) + dirent := vfs.Dirent{ + Name: ".", + Type: linux.DT_DIR, + Ino: stat.Ino, + NextOff: 1, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + // Handle "..". + if fd.off == 1 { + parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode + stat := parentInode.Stat(vfsFS) + dirent := vfs.Dirent{ + Name: "..", + Type: linux.FileMode(stat.Mode).DirentType(), + Ino: stat.Ino, + NextOff: 2, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + // Handle static children. + fd.children.mu.RLock() + defer fd.children.mu.RUnlock() + // fd.off accounts for "." and "..", but fd.children do not track + // these. + childIdx := fd.off - 2 + for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { + inode := it.Dentry.Impl().(*Dentry).inode + stat := inode.Stat(vfsFS) + dirent := vfs.Dirent{ + Name: it.Name, + Type: linux.FileMode(stat.Mode).DirentType(), + Ino: stat.Ino, + NextOff: fd.off + 1, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + return nil +} + +// Seek implements vfs.FileDecriptionImpl.Seek. +func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fs := fd.filesystem().Impl().(*Filesystem) + fs.mu.Lock() + defer fs.mu.Unlock() + + switch whence { + case linux.SEEK_SET: + // Use offset as given. + case linux.SEEK_CUR: + offset += fd.off + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + +// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. +func (fd *GenericDirectoryFD) StatusFlags(ctx context.Context) (uint32, error) { + return fd.flags, nil +} + +// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. +func (fd *GenericDirectoryFD) SetStatusFlags(ctx context.Context, flags uint32) error { + // None of the flags settable by fcntl(F_SETFL) are supported, so this is a + // no-op. + return nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.filesystem() + inode := fd.inode() + return inode.Stat(fs), nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + fs := fd.filesystem() + inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return inode.SetStat(fs, opts) +} diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go new file mode 100644 index 000000000..3cbbe4b20 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -0,0 +1,743 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements vfs.FilesystemImpl for kernfs. + +package kernfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// stepExistingLocked resolves rp.Component() in parent directory vfsd. +// +// stepExistingLocked is loosely analogous to fs/namei.c:walk_component(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// +// Postcondition: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) (*vfs.Dentry, error) { + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, syserror.ENOTDIR + } + // Directory searchable? + if err := d.inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } +afterSymlink: + d.dirMu.Lock() + nextVFSD, err := rp.ResolveComponent(vfsd) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + if nextVFSD != nil { + // Cached dentry exists, revalidate. + next := nextVFSD.Impl().(*Dentry) + if !next.inode.Valid(ctx) { + d.dirMu.Lock() + rp.VirtualFilesystem().ForceDeleteDentry(nextVFSD) + d.dirMu.Unlock() + fs.deferDecRef(nextVFSD) // Reference from Lookup. + nextVFSD = nil + } + } + if nextVFSD == nil { + // Dentry isn't cached; it either doesn't exist or failed + // revalidation. Attempt to resolve it via Lookup. + name := rp.Component() + var err error + nextVFSD, err = d.inode.Lookup(ctx, name) + // Reference on nextVFSD dropped by a corresponding Valid. + if err != nil { + return nil, err + } + d.InsertChild(name, nextVFSD) + } + next := nextVFSD.Impl().(*Dentry) + + // Resolve any symlink at current path component. + if rp.ShouldFollowSymlink() && d.isSymlink() { + // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks". + target, err := next.inode.Readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + goto afterSymlink + + } + rp.Advance() + return nextVFSD, nil +} + +// walkExistingLocked resolves rp to an existing file. +// +// walkExistingLocked is loosely analogous to Linux's +// fs/namei.c:path_lookupat(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { + vfsd := rp.Start() + for !rp.Done() { + var err error + vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) + if err != nil { + return nil, nil, err + } + } + d := vfsd.Impl().(*Dentry) + if rp.MustBeDir() && !d.isDir() { + return nil, nil, syserror.ENOTDIR + } + return vfsd, d.inode, nil +} + +// walkParentDirLocked resolves all but the last path component of rp to an +// existing directory. It does not check that the returned directory is +// searchable by the provider of rp. +// +// walkParentDirLocked is loosely analogous to Linux's +// fs/namei.c:path_parentat(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { + vfsd := rp.Start() + for !rp.Final() { + var err error + vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) + if err != nil { + return nil, nil, err + } + } + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, nil, syserror.ENOTDIR + } + return vfsd, d.inode, nil +} + +// checkCreateLocked checks that a file named rp.Component() may be created in +// directory parentVFSD, then returns rp.Component(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. parentInode +// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true. +func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return "", err + } + pc := rp.Component() + if pc == "." || pc == ".." { + return "", syserror.EEXIST + } + childVFSD, err := rp.ResolveChild(parentVFSD, pc) + if err != nil { + return "", err + } + if childVFSD != nil { + return "", syserror.EEXIST + } + if parentVFSD.IsDisowned() { + return "", syserror.ENOENT + } + return pc, nil +} + +// checkDeleteLocked checks that the file represented by vfsd may be deleted. +// +// Preconditions: Filesystem.mu must be locked for at least reading. +func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error { + parentVFSD := vfsd.Parent() + if parentVFSD == nil { + return syserror.EBUSY + } + if parentVFSD.IsDisowned() { + return syserror.ENOENT + } + if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + return nil +} + +// checkRenameLocked checks that a rename operation may be performed on the +// target dentry across the given set of parent directories. The target dentry +// may be nil. +// +// Precondition: isDir(dstInode) == true. +func checkRenameLocked(creds *auth.Credentials, src, dstDir *vfs.Dentry, dstInode Inode) error { + srcDir := src.Parent() + if srcDir == nil { + return syserror.EBUSY + } + if srcDir.IsDisowned() { + return syserror.ENOENT + } + if dstDir.IsDisowned() { + return syserror.ENOENT + } + // Check for creation permissions on dst dir. + if err := dstInode.CheckPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + + return nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *Filesystem) Release() { +} + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *Filesystem) Sync(ctx context.Context) error { + // All filesystem state is in-memory. + return nil +} + +// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. +func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return nil, err + } + + if opts.CheckSearchable { + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, syserror.ENOTDIR + } + if err := inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + } + vfsd.IncRef() // Ownership transferred to caller. + return vfsd, nil +} + +// LinkAt implements vfs.FilesystemImpl.LinkAt. +func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if rp.Mount() != vd.Mount() { + return syserror.EXDEV + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + + d := vd.Dentry().Impl().(*Dentry) + if d.isDir() { + return syserror.EPERM + } + + child, err := parentInode.NewLink(ctx, pc, d.inode) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + child, err := parentInode.NewDir(ctx, pc, opts) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. +func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + new, err := parentInode.NewNode(ctx, pc, opts) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, new) + return nil +} + +// OpenAt implements vfs.FilesystemImpl.OpenAt. +func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // Filter out flags that are not supported by kernfs. O_DIRECTORY and + // O_NOFOLLOW have no effect here (they're handled by VFS by setting + // appropriate bits in rp), but are returned by + // FileDescriptionImpl.StatusFlags(). + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + ats := vfs.AccessTypesForOpenFlags(opts.Flags) + + // Do not create new file. + if opts.Flags&linux.O_CREAT == 0 { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return nil, err + } + if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return inode.Open(rp, vfsd, opts.Flags) + } + + // May create new file. + mustCreate := opts.Flags&linux.O_EXCL != 0 + vfsd := rp.Start() + inode := vfsd.Impl().(*Dentry).inode + fs.mu.Lock() + defer fs.mu.Unlock() + if rp.Done() { + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + if mustCreate { + return nil, syserror.EEXIST + } + if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return inode.Open(rp, vfsd, opts.Flags) + } +afterTrailingSymlink: + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return nil, err + } + // Check for search permission in the parent directory. + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + // Reject attempts to open directories with O_CREAT. + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + pc := rp.Component() + if pc == "." || pc == ".." { + return nil, syserror.EISDIR + } + // Determine whether or not we need to create a file. + childVFSD, err := rp.ResolveChild(parentVFSD, pc) + if err != nil { + return nil, err + } + if childVFSD == nil { + // Already checked for searchability above; now check for writability. + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + defer rp.Mount().EndWrite() + // Create and open the child. + child, err := parentInode.NewFile(ctx, pc, opts) + if err != nil { + return nil, err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return child.Impl().(*Dentry).inode.Open(rp, child, opts.Flags) + } + // Open existing file or follow symlink. + if mustCreate { + return nil, syserror.EEXIST + } + childDentry := childVFSD.Impl().(*Dentry) + childInode := childDentry.inode + if rp.ShouldFollowSymlink() { + if childDentry.isSymlink() { + target, err := childInode.Readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + // rp.Final() may no longer be true since we now need to resolve the + // symlink target. + goto afterTrailingSymlink + } + } + if err := childInode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return childInode.Open(rp, childVFSD, opts.Flags) +} + +// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. +func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { + fs.mu.RLock() + d, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return "", err + } + if !d.Impl().(*Dentry).isSymlink() { + return "", syserror.EINVAL + } + return inode.Readlink(ctx) +} + +// RenameAt implements vfs.FilesystemImpl.RenameAt. +func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error { + noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 + exchange := opts.Flags&linux.RENAME_EXCHANGE != 0 + whiteout := opts.Flags&linux.RENAME_WHITEOUT != 0 + if exchange && (noReplace || whiteout) { + // Can't specify RENAME_NOREPLACE or RENAME_WHITEOUT with RENAME_EXCHANGE. + return syserror.EINVAL + } + if exchange || whiteout { + // Exchange and Whiteout flags are not supported on kernfs. + return syserror.EINVAL + } + + fs.mu.Lock() + defer fs.mu.Lock() + + mnt := rp.Mount() + if mnt != vd.Mount() { + return syserror.EXDEV + } + + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + + dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + + srcVFSD := vd.Dentry() + srcDirVFSD := srcVFSD.Parent() + + // Can we remove the src dentry? + if err := checkDeleteLocked(rp, srcVFSD); err != nil { + return err + } + + // Can we create the dst dentry? + var dstVFSD *vfs.Dentry + pc, err := checkCreateLocked(rp, dstDirVFSD, dstDirInode) + switch err { + case nil: + // Ok, continue with rename as replacement. + case syserror.EEXIST: + if noReplace { + // Won't overwrite existing node since RENAME_NOREPLACE was requested. + return syserror.EEXIST + } + dstVFSD, err = rp.ResolveChild(dstDirVFSD, pc) + if err != nil { + panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD)) + } + default: + return err + } + + mntns := vfs.MountNamespaceFromContext(ctx) + virtfs := rp.VirtualFilesystem() + + srcDirDentry := srcDirVFSD.Impl().(*Dentry) + dstDirDentry := dstDirVFSD.Impl().(*Dentry) + + // We can't deadlock here due to lock ordering because we're protected from + // concurrent renames by fs.mu held for writing. + srcDirDentry.dirMu.Lock() + defer srcDirDentry.dirMu.Unlock() + dstDirDentry.dirMu.Lock() + defer dstDirDentry.dirMu.Unlock() + + if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil { + return err + } + srcDirInode := srcDirDentry.inode + replaced, err := srcDirInode.Rename(ctx, srcVFSD.Name(), pc, srcVFSD, dstDirVFSD) + if err != nil { + virtfs.AbortRenameDentry(srcVFSD, dstVFSD) + return err + } + virtfs.CommitRenameReplaceDentry(srcVFSD, dstDirVFSD, pc, replaced) + return nil +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + if err := checkDeleteLocked(rp, vfsd); err != nil { + return err + } + if !vfsd.Impl().(*Dentry).isDir() { + return syserror.ENOTDIR + } + if inode.HasChildren() { + return syserror.ENOTEMPTY + } + virtfs := rp.VirtualFilesystem() + parentDentry := vfsd.Parent().Impl().(*Dentry) + parentDentry.dirMu.Lock() + defer parentDentry.dirMu.Unlock() + if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { + return err + } + if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { + virtfs.AbortDeleteDentry(vfsd) + return err + } + virtfs.CommitDeleteDentry(vfsd) + return nil +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + fs.mu.RLock() + _, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + if opts.Stat.Mask == 0 { + return nil + } + return inode.SetStat(fs.VFSFilesystem(), opts) +} + +// StatAt implements vfs.FilesystemImpl.StatAt. +func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { + fs.mu.RLock() + _, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return linux.Statx{}, err + } + return inode.Stat(fs.VFSFilesystem()), nil +} + +// StatFSAt implements vfs.FilesystemImpl.StatFSAt. +func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return linux.Statfs{}, err + } + // TODO: actually implement statfs + return linux.Statfs{}, syserror.ENOSYS +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. +func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + child, err := parentInode.NewSymlink(ctx, pc, target) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. +func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + vfsd, _, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + if err := checkDeleteLocked(rp, vfsd); err != nil { + return err + } + if vfsd.Impl().(*Dentry).isDir() { + return syserror.EISDIR + } + virtfs := rp.VirtualFilesystem() + parentDentry := vfsd.Parent().Impl().(*Dentry) + parentDentry.dirMu.Lock() + defer parentDentry.dirMu.Unlock() + if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { + return err + } + if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { + virtfs.AbortDeleteDentry(vfsd) + return err + } + virtfs.CommitDeleteDentry(vfsd) + return nil +} + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return nil, err + } + // kernfs currently does not support extended attributes. + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return "", err + } + // kernfs currently does not support extended attributes. + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go new file mode 100644 index 000000000..7b45b702a --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -0,0 +1,492 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "fmt" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// InodeNoopRefCount partially implements the Inode interface, specifically the +// inodeRefs sub interface. InodeNoopRefCount implements a simple reference +// count for inodes, performing no extra actions when references are obtained or +// released. This is suitable for simple file inodes that don't reference any +// resources. +type InodeNoopRefCount struct { +} + +// IncRef implements Inode.IncRef. +func (n *InodeNoopRefCount) IncRef() { +} + +// DecRef implements Inode.DecRef. +func (n *InodeNoopRefCount) DecRef() { +} + +// TryIncRef implements Inode.TryIncRef. +func (n *InodeNoopRefCount) TryIncRef() bool { + return true +} + +// Destroy implements Inode.Destroy. +func (n *InodeNoopRefCount) Destroy() { +} + +// InodeDirectoryNoNewChildren partially implements the Inode interface. +// InodeDirectoryNoNewChildren represents a directory inode which does not +// support creation of new children. +type InodeDirectoryNoNewChildren struct{} + +// NewFile implements Inode.NewFile. +func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewDir implements Inode.NewDir. +func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewLink implements Inode.NewLink. +func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewSymlink implements Inode.NewSymlink. +func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewNode implements Inode.NewNode. +func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// InodeNotDirectory partially implements the Inode interface, specifically the +// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not +// represent directories can embed this to provide no-op implementations for +// directory-related functions. +type InodeNotDirectory struct { +} + +// HasChildren implements Inode.HasChildren. +func (*InodeNotDirectory) HasChildren() bool { + return false +} + +// NewFile implements Inode.NewFile. +func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { + panic("NewFile called on non-directory inode") +} + +// NewDir implements Inode.NewDir. +func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { + panic("NewDir called on non-directory inode") +} + +// NewLink implements Inode.NewLinkink. +func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { + panic("NewLink called on non-directory inode") +} + +// NewSymlink implements Inode.NewSymlink. +func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + panic("NewSymlink called on non-directory inode") +} + +// NewNode implements Inode.NewNode. +func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + panic("NewNode called on non-directory inode") +} + +// Unlink implements Inode.Unlink. +func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { + panic("Unlink called on non-directory inode") +} + +// RmDir implements Inode.RmDir. +func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { + panic("RmDir called on non-directory inode") +} + +// Rename implements Inode.Rename. +func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { + panic("Rename called on non-directory inode") +} + +// Lookup implements Inode.Lookup. +func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + panic("Lookup called on non-directory inode") +} + +// Valid implements Inode.Valid. +func (*InodeNotDirectory) Valid(context.Context) bool { + return true +} + +// InodeNoDynamicLookup partially implements the Inode interface, specifically +// the inodeDynamicLookup sub interface. Directory inodes that do not support +// dymanic entries (i.e. entries that are not "hashed" into the +// vfs.Dentry.children) can embed this to provide no-op implementations for +// functions related to dynamic entries. +type InodeNoDynamicLookup struct{} + +// Lookup implements Inode.Lookup. +func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + return nil, syserror.ENOENT +} + +// Valid implements Inode.Valid. +func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool { + return true +} + +// InodeNotSymlink partially implements the Inode interface, specifically the +// inodeSymlink sub interface. All inodes that are not symlinks may embed this +// to return the appropriate errors from symlink-related functions. +type InodeNotSymlink struct{} + +// Readlink implements Inode.Readlink. +func (*InodeNotSymlink) Readlink(context.Context) (string, error) { + return "", syserror.EINVAL +} + +// InodeAttrs partially implements the Inode interface, specifically the +// inodeMetadata sub interface. InodeAttrs provides functionality related to +// inode attributes. +// +// Must be initialized by Init prior to first use. +type InodeAttrs struct { + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 +} + +// Init initializes this InodeAttrs. +func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMode) { + if mode.FileType() == 0 { + panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) + } + + nlink := uint32(1) + if mode.FileType() == linux.ModeDirectory { + nlink = 2 + } + atomic.StoreUint64(&a.ino, ino) + atomic.StoreUint32(&a.mode, uint32(mode)) + atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) + atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) + atomic.StoreUint32(&a.nlink, nlink) +} + +// Mode implements Inode.Mode. +func (a *InodeAttrs) Mode() linux.FileMode { + return linux.FileMode(atomic.LoadUint32(&a.mode)) +} + +// Stat partially implements Inode.Stat. Note that this function doesn't provide +// all the stat fields, and the embedder should consider extending the result +// with filesystem-specific fields. +func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx { + var stat linux.Statx + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Ino = atomic.LoadUint64(&a.ino) + stat.Mode = uint16(a.Mode()) + stat.UID = atomic.LoadUint32(&a.uid) + stat.GID = atomic.LoadUint32(&a.gid) + stat.Nlink = atomic.LoadUint32(&a.nlink) + + // TODO: Implement other stat fields like timestamps. + + return stat +} + +// SetStat implements Inode.SetStat. +func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error { + stat := opts.Stat + if stat.Mask&linux.STATX_MODE != 0 { + for { + old := atomic.LoadUint32(&a.mode) + new := old | uint32(stat.Mode & ^uint16(linux.S_IFMT)) + if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped { + break + } + } + } + + if stat.Mask&linux.STATX_UID != 0 { + atomic.StoreUint32(&a.uid, stat.UID) + } + if stat.Mask&linux.STATX_GID != 0 { + atomic.StoreUint32(&a.gid, stat.GID) + } + + // Note that not all fields are modifiable. For example, the file type and + // inode numbers are immutable after node creation. + + // TODO: Implement other stat fields like timestamps. + + return nil +} + +// CheckPermissions implements Inode.CheckPermissions. +func (a *InodeAttrs) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { + mode := a.Mode() + return vfs.GenericCheckPermissions( + creds, + ats, + mode.FileType() == linux.ModeDirectory, + uint16(mode), + auth.KUID(atomic.LoadUint32(&a.uid)), + auth.KGID(atomic.LoadUint32(&a.gid)), + ) +} + +// IncLinks implements Inode.IncLinks. +func (a *InodeAttrs) IncLinks(n uint32) { + if atomic.AddUint32(&a.nlink, n) <= n { + panic("InodeLink.IncLinks called with no existing links") + } +} + +// DecLinks implements Inode.DecLinks. +func (a *InodeAttrs) DecLinks() { + if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) { + // Negative overflow + panic("Inode.DecLinks called at 0 links") + } +} + +type slot struct { + Name string + Dentry *vfs.Dentry + slotEntry +} + +// OrderedChildrenOptions contains initialization options for OrderedChildren. +type OrderedChildrenOptions struct { + // Writable indicates whether vfs.FilesystemImpl methods implemented by + // OrderedChildren may modify the tracked children. This applies to + // operations related to rename, unlink and rmdir. If an OrderedChildren is + // not writable, these operations all fail with EPERM. + Writable bool +} + +// OrderedChildren partially implements the Inode interface. OrderedChildren can +// be embedded in directory inodes to keep track of the children in the +// directory, and can then be used to implement a generic directory FD -- see +// GenericDirectoryFD. OrderedChildren is not compatible with dynamic +// directories. +// +// Must be initialize with Init before first use. +type OrderedChildren struct { + refs.AtomicRefCount + + // Can children be modified by user syscalls? It set to false, interface + // methods that would modify the children return EPERM. Immutable. + writable bool + + mu sync.RWMutex + order slotList + set map[string]*slot +} + +// Init initializes an OrderedChildren. +func (o *OrderedChildren) Init(opts OrderedChildrenOptions) { + o.writable = opts.Writable + o.set = make(map[string]*slot) +} + +// DecRef implements Inode.DecRef. +func (o *OrderedChildren) DecRef() { + o.AtomicRefCount.DecRefWithDestructor(o.Destroy) +} + +// Destroy cleans up resources referenced by this OrderedChildren. +func (o *OrderedChildren) Destroy() { + o.mu.Lock() + defer o.mu.Unlock() + o.order.Reset() + o.set = nil +} + +// Populate inserts children into this OrderedChildren, and d's dentry +// cache. Populate returns the number of directories inserted, which the caller +// may use to update the link count for the parent directory. +// +// Precondition: d.Impl() must be a kernfs Dentry. d must represent a directory +// inode. children must not contain any conflicting entries already in o. +func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 { + var links uint32 + for name, child := range children { + if child.isDir() { + links++ + } + if err := o.Insert(name, child.VFSDentry()); err != nil { + panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d)) + } + d.InsertChild(name, child.VFSDentry()) + } + return links +} + +// HasChildren implements Inode.HasChildren. +func (o *OrderedChildren) HasChildren() bool { + o.mu.RLock() + defer o.mu.RUnlock() + return len(o.set) > 0 +} + +// Insert inserts child into o. This ignores the writability of o, as this is +// not part of the vfs.FilesystemImpl interface, and is a lower-level operation. +func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error { + o.mu.Lock() + defer o.mu.Unlock() + if _, ok := o.set[name]; ok { + return syserror.EEXIST + } + s := &slot{ + Name: name, + Dentry: child, + } + o.order.PushBack(s) + o.set[name] = s + return nil +} + +// Precondition: caller must hold o.mu for writing. +func (o *OrderedChildren) removeLocked(name string) { + if s, ok := o.set[name]; ok { + delete(o.set, name) + o.order.Remove(s) + } +} + +// Precondition: caller must hold o.mu for writing. +func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry { + if s, ok := o.set[name]; ok { + // Existing slot with given name, simply replace the dentry. + var old *vfs.Dentry + old, s.Dentry = s.Dentry, new + return old + } + + // No existing slot with given name, create and hash new slot. + s := &slot{ + Name: name, + Dentry: new, + } + o.order.PushBack(s) + o.set[name] = s + return nil +} + +// Precondition: caller must hold o.mu for reading or writing. +func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error { + s, ok := o.set[name] + if !ok { + return syserror.ENOENT + } + if s.Dentry != child { + panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child)) + } + return nil +} + +// Unlink implements Inode.Unlink. +func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { + if !o.writable { + return syserror.EPERM + } + o.mu.Lock() + defer o.mu.Unlock() + if err := o.checkExistingLocked(name, child); err != nil { + return err + } + o.removeLocked(name) + return nil +} + +// Rmdir implements Inode.Rmdir. +func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { + // We're not responsible for checking that child is a directory, that it's + // empty, or updating any link counts; so this is the same as unlink. + return o.Unlink(ctx, name, child) +} + +type renameAcrossDifferentImplementationsError struct{} + +func (renameAcrossDifferentImplementationsError) Error() string { + return "rename across inodes with different implementations" +} + +// Rename implements Inode.Rename. +// +// Precondition: Rename may only be called across two directory inodes with +// identical implementations of Rename. Practically, this means filesystems that +// implement Rename by embedding OrderedChildren for any directory +// implementation must use OrderedChildren for all directory implementations +// that will support Rename. +// +// Postcondition: reference on any replaced dentry transferred to caller. +func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) { + dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren) + if !ok { + return nil, renameAcrossDifferentImplementationsError{} + } + if !o.writable || !dst.writable { + return nil, syserror.EPERM + } + // Note: There's a potential deadlock below if concurrent calls to Rename + // refer to the same src and dst directories in reverse. We avoid any + // ordering issues because the caller is required to serialize concurrent + // calls to Rename in accordance with the interface declaration. + o.mu.Lock() + defer o.mu.Unlock() + if dst != o { + dst.mu.Lock() + defer dst.mu.Unlock() + } + if err := o.checkExistingLocked(oldname, child); err != nil { + return nil, err + } + replaced := dst.replaceChildLocked(newname, child) + return replaced, nil +} + +// nthLocked returns an iterator to the nth child tracked by this object. The +// iterator is valid until the caller releases o.mu. Returns nil if the +// requested index falls out of bounds. +// +// Preconditon: Caller must hold o.mu for reading. +func (o *OrderedChildren) nthLocked(i int64) *slot { + for it := o.order.Front(); it != nil && i >= 0; it = it.Next() { + if i == 0 { + return it + } + i-- + } + return nil +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go new file mode 100644 index 000000000..bb01c3d01 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -0,0 +1,405 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package kernfs provides the tools to implement inode-based filesystems. +// Kernfs has two main features: +// +// 1. The Inode interface, which maps VFS2's path-based filesystem operations to +// specific filesystem nodes. Kernfs uses the Inode interface to provide a +// blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as +// the synchronization mechanism for all filesystem operations by holding a +// filesystem-wide lock across all operations. +// +// 2. Various utility types which provide generic implementations for various +// parts of the Inode and vfs.FileDescription interfaces. Client filesystems +// based on kernfs can embed the appropriate set of these to avoid having to +// reimplement common filesystem operations. See inode_impl_util.go and +// fd_impl_util.go. +// +// Reference Model: +// +// Kernfs dentries represents named pointers to inodes. Dentries and inode have +// independent lifetimes and reference counts. A child dentry unconditionally +// holds a reference on its parent directory's dentry. A dentry also holds a +// reference on the inode it points to. Multiple dentries can point to the same +// inode (for example, in the case of hardlinks). File descriptors hold a +// reference to the dentry they're opened on. +// +// Dentries are guaranteed to exist while holding Filesystem.mu for +// reading. Dropping dentries require holding Filesystem.mu for writing. To +// queue dentries for destruction from a read critical section, see +// Filesystem.deferDecRef. +// +// Lock ordering: +// +// kernfs.Filesystem.mu +// kernfs.Dentry.dirMu +// vfs.VirtualFilesystem.mountMu +// vfs.Dentry.mu +// kernfs.Filesystem.droppedDentriesMu +// (inode implementation locks, if any) +package kernfs + +import ( + "fmt" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// FilesystemType implements vfs.FilesystemType. +type FilesystemType struct{} + +// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory +// filesystem. Concrete implementations are expected to embed this in their own +// Filesystem type. +type Filesystem struct { + vfsfs vfs.Filesystem + + droppedDentriesMu sync.Mutex + + // droppedDentries is a list of dentries waiting to be DecRef()ed. This is + // used to defer dentry destruction until mu can be acquired for + // writing. Protected by droppedDentriesMu. + droppedDentries []*vfs.Dentry + + // mu synchronizes the lifetime of Dentries on this filesystem. Holding it + // for reading guarantees continued existence of any resolved dentries, but + // the dentry tree may be modified. + // + // Kernfs dentries can only be DecRef()ed while holding mu for writing. For + // example: + // + // fs.mu.Lock() + // defer fs.mu.Unlock() + // ... + // dentry1.DecRef() + // defer dentry2.DecRef() // Ok, will run before Unlock. + // + // If discarding dentries in a read context, use Filesystem.deferDecRef. For + // example: + // + // fs.mu.RLock() + // fs.mu.processDeferredDecRefs() + // defer fs.mu.RUnlock() + // ... + // fs.deferDecRef(dentry) + mu sync.RWMutex + + // nextInoMinusOne is used to to allocate inode numbers on this + // filesystem. Must be accessed by atomic operations. + nextInoMinusOne uint64 +} + +// deferDecRef defers dropping a dentry ref until the next call to +// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu. +// +// Precondition: d must not already be pending destruction. +func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { + fs.droppedDentriesMu.Lock() + fs.droppedDentries = append(fs.droppedDentries, d) + fs.droppedDentriesMu.Unlock() +} + +// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the +// droppedDentries list. See comment on Filesystem.mu. +func (fs *Filesystem) processDeferredDecRefs() { + fs.mu.Lock() + fs.processDeferredDecRefsLocked() + fs.mu.Unlock() +} + +// Precondition: fs.mu must be held for writing. +func (fs *Filesystem) processDeferredDecRefsLocked() { + fs.droppedDentriesMu.Lock() + for _, d := range fs.droppedDentries { + d.DecRef() + } + fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse. + fs.droppedDentriesMu.Unlock() +} + +// Init initializes a kernfs filesystem. This should be called from during +// vfs.FilesystemType.NewFilesystem for the concrete filesystem embedding +// kernfs. +func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) { + fs.vfsfs.Init(vfsObj, fs) +} + +// VFSFilesystem returns the generic vfs filesystem object. +func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem { + return &fs.vfsfs +} + +// NextIno allocates a new inode number on this filesystem. +func (fs *Filesystem) NextIno() uint64 { + return atomic.AddUint64(&fs.nextInoMinusOne, 1) +} + +// These consts are used in the Dentry.flags field. +const ( + // Dentry points to a directory inode. + dflagsIsDir = 1 << iota + + // Dentry points to a symlink inode. + dflagsIsSymlink +) + +// Dentry implements vfs.DentryImpl. +// +// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a +// named reference to an inode. A dentry generally lives as long as it's part of +// a mounted filesystem tree. Kernfs doesn't cache dentries once all references +// to them are removed. Dentries hold a single reference to the inode they point +// to, and child dentries hold a reference on their parent. +// +// Must be initialized by Init prior to first use. +type Dentry struct { + refs.AtomicRefCount + + vfsd vfs.Dentry + inode Inode + + refs uint64 + + // flags caches useful information about the dentry from the inode. See the + // dflags* consts above. Must be accessed by atomic ops. + flags uint32 + + // dirMu protects vfsd.children for directory dentries. + dirMu sync.Mutex +} + +// Init initializes this dentry. +// +// Precondition: Caller must hold a reference on inode. +// +// Postcondition: Caller's reference on inode is transferred to the dentry. +func (d *Dentry) Init(inode Inode) { + d.vfsd.Init(d) + d.inode = inode + ftype := inode.Mode().FileType() + if ftype == linux.ModeDirectory { + d.flags |= dflagsIsDir + } + if ftype == linux.ModeSymlink { + d.flags |= dflagsIsSymlink + } +} + +// VFSDentry returns the generic vfs dentry for this kernfs dentry. +func (d *Dentry) VFSDentry() *vfs.Dentry { + return &d.vfsd +} + +// isDir checks whether the dentry points to a directory inode. +func (d *Dentry) isDir() bool { + return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0 +} + +// isSymlink checks whether the dentry points to a symlink inode. +func (d *Dentry) isSymlink() bool { + return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef() { + d.AtomicRefCount.DecRefWithDestructor(d.destroy) +} + +// Precondition: Dentry must be removed from VFS' dentry cache. +func (d *Dentry) destroy() { + d.inode.DecRef() // IncRef from Init. + d.inode = nil + if parent := d.vfsd.Parent(); parent != nil { + parent.DecRef() // IncRef from Dentry.InsertChild. + } +} + +// InsertChild inserts child into the vfs dentry cache with the given name under +// this dentry. This does not update the directory inode, so calling this on +// it's own isn't sufficient to insert a child into a directory. InsertChild +// updates the link count on d if required. +// +// Precondition: d must represent a directory inode. +func (d *Dentry) InsertChild(name string, child *vfs.Dentry) { + if !d.isDir() { + panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) + } + vfsDentry := d.VFSDentry() + vfsDentry.IncRef() // DecRef in child's Dentry.destroy. + d.dirMu.Lock() + vfsDentry.InsertChild(child, name) + d.dirMu.Unlock() +} + +// The Inode interface maps filesystem-level operations that operate on paths to +// equivalent operations on specific filesystem nodes. +// +// The interface methods are groups into logical categories as sub interfaces +// below. Generally, an implementation for each sub interface can be provided by +// embedding an appropriate type from inode_impl_utils.go. The sub interfaces +// are purely organizational. Methods declared directly in the main interface +// have no generic implementations, and should be explicitly provided by the +// client filesystem. +// +// Generally, implementations are not responsible for tasks that are common to +// all filesystems. These include: +// +// - Checking that dentries passed to methods are of the appropriate file type. +// - Checking permissions. +// - Updating link and reference counts. +// +// Specific responsibilities of implementations are documented below. +type Inode interface { + // Methods related to reference counting. A generic implementation is + // provided by InodeNoopRefCount. These methods are generally called by the + // equivalent Dentry methods. + inodeRefs + + // Methods related to node metadata. A generic implementation is provided by + // InodeAttrs. + inodeMetadata + + // Method for inodes that represent symlink. InodeNotSymlink provides a + // blanket implementation for all non-symlink inodes. + inodeSymlink + + // Method for inodes that represent directories. InodeNotDirectory provides + // a blanket implementation for all non-directory inodes. + inodeDirectory + + // Method for inodes that represent dynamic directories and their + // children. InodeNoDynamicLookup provides a blanket implementation for all + // non-dynamic-directory inodes. + inodeDynamicLookup + + // Open creates a file description for the filesystem object represented by + // this inode. The returned file description should hold a reference on the + // inode for its lifetime. + // + // Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry. + Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) +} + +type inodeRefs interface { + IncRef() + DecRef() + TryIncRef() bool + // Destroy is called when the inode reaches zero references. Destroy release + // all resources (references) on objects referenced by the inode, including + // any child dentries. + Destroy() +} + +type inodeMetadata interface { + // CheckPermissions checks that creds may access this inode for the + // requested access type, per the the rules of + // fs/namei.c:generic_permission(). + CheckPermissions(creds *auth.Credentials, atx vfs.AccessTypes) error + + // Mode returns the (struct stat)::st_mode value for this inode. This is + // separated from Stat for performance. + Mode() linux.FileMode + + // Stat returns the metadata for this inode. This corresponds to + // vfs.FilesystemImpl.StatAt. + Stat(fs *vfs.Filesystem) linux.Statx + + // SetStat updates the metadata for this inode. This corresponds to + // vfs.FilesystemImpl.SetStatAt. + SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error +} + +// Precondition: All methods in this interface may only be called on directory +// inodes. +type inodeDirectory interface { + // The New{File,Dir,Node,Symlink} methods below should return a new inode + // hashed into this inode. + // + // These inode constructors are inode-level operations rather than + // filesystem-level operations to allow client filesystems to mix different + // implementations based on the new node's location in the + // filesystem. + + // HasChildren returns true if the directory inode has any children. + HasChildren() bool + + // NewFile creates a new regular file inode. + NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) + + // NewDir creates a new directory inode. + NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) + + // NewLink creates a new hardlink to a specified inode in this + // directory. Implementations should create a new kernfs Dentry pointing to + // target, and update target's link count. + NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) + + // NewSymlink creates a new symbolic link inode. + NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) + + // NewNode creates a new filesystem node for a mknod syscall. + NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) + + // Unlink removes a child dentry from this directory inode. + Unlink(ctx context.Context, name string, child *vfs.Dentry) error + + // RmDir removes an empty child directory from this directory + // inode. Implementations must update the parent directory's link count, + // if required. Implementations are not responsible for checking that child + // is a directory, checking for an empty directory. + RmDir(ctx context.Context, name string, child *vfs.Dentry) error + + // Rename is called on the source directory containing an inode being + // renamed. child should point to the resolved child in the source + // directory. If Rename replaces a dentry in the destination directory, it + // should return the replaced dentry or nil otherwise. + // + // Precondition: Caller must serialize concurrent calls to Rename. + Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error) +} + +type inodeDynamicLookup interface { + // Lookup should return an appropriate dentry if name should resolve to a + // child of this dynamic directory inode. This gives the directory an + // opportunity on every lookup to resolve additional entries that aren't + // hashed into the directory. This is only called when the inode is a + // directory. If the inode is not a directory, or if the directory only + // contains a static set of children, the implementer can unconditionally + // return an appropriate error (ENOTDIR and ENOENT respectively). + // + // The child returned by Lookup will be hashed into the VFS dentry tree. Its + // lifetime can be controlled by the filesystem implementation with an + // appropriate implementation of Valid. + // + // Lookup returns the child with an extra reference and the caller owns this + // reference. + Lookup(ctx context.Context, name string) (*vfs.Dentry, error) + + // Valid should return true if this inode is still valid, or needs to + // be resolved again by a call to Lookup. + Valid(ctx context.Context) bool +} + +type inodeSymlink interface { + // Readlink resolves the target of a symbolic link. If an inode is not a + // symlink, the implementation should return EINVAL. + Readlink(ctx context.Context) (string, error) +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go new file mode 100644 index 000000000..f78bb7b04 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -0,0 +1,423 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs_test + +import ( + "bytes" + "fmt" + "io" + "runtime" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const defaultMode linux.FileMode = 01777 +const staticFileContent = "This is sample content for a static test file." + +// RootDentryFn is a generator function for creating the root dentry of a test +// filesystem. See newTestSystem. +type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry + +// TestSystem represents the context for a single test. +type TestSystem struct { + t *testing.T + ctx context.Context + creds *auth.Credentials + vfs *vfs.VirtualFilesystem + mns *vfs.MountNamespace + root vfs.VirtualDentry +} + +// newTestSystem sets up a minimal environment for running a test, including an +// instance of a test filesystem. Tests can control the contents of the +// filesystem by providing an appropriate rootFn, which should return a +// pre-populated root dentry. +func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + v := vfs.New() + v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}) + mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("Failed to create testfs root mount: %v", err) + } + + s := &TestSystem{ + t: t, + ctx: ctx, + creds: creds, + vfs: v, + mns: mns, + root: mns.Root(), + } + runtime.SetFinalizer(s, func(s *TestSystem) { s.root.DecRef() }) + return s +} + +// PathOpAtRoot constructs a vfs.PathOperation for a path from the +// root of the test filesystem. +// +// Precondition: path should be relative path. +func (s *TestSystem) PathOpAtRoot(path string) vfs.PathOperation { + return vfs.PathOperation{ + Root: s.root, + Start: s.root, + Pathname: path, + } +} + +// GetDentryOrDie attempts to resolve a dentry referred to by the +// provided path operation. If unsuccessful, the test fails. +func (s *TestSystem) GetDentryOrDie(pop vfs.PathOperation) vfs.VirtualDentry { + vd, err := s.vfs.GetDentryAt(s.ctx, s.creds, &pop, &vfs.GetDentryOptions{}) + if err != nil { + s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err) + } + return vd +} + +func (s *TestSystem) ReadToEnd(fd *vfs.FileDescription) (string, error) { + buf := make([]byte, usermem.PageSize) + bufIOSeq := usermem.BytesIOSequence(buf) + opts := vfs.ReadOptions{} + + var content bytes.Buffer + for { + n, err := fd.Impl().Read(s.ctx, bufIOSeq, opts) + if n == 0 || err != nil { + if err == io.EOF { + err = nil + } + return content.String(), err + } + content.Write(buf[:n]) + } +} + +type fsType struct { + rootFn RootDentryFn +} + +type filesystem struct { + kernfs.Filesystem +} + +type file struct { + kernfs.DynamicBytesFile + content string +} + +func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry { + f := &file{} + f.content = content + f.DynamicBytesFile.Init(creds, fs.NextIno(), f) + + d := &kernfs.Dentry{} + d.Init(f) + return d +} + +func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%s", f.content) + return nil +} + +type attrs struct { + kernfs.InodeAttrs +} + +func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error { + return syserror.EPERM +} + +type readonlyDir struct { + attrs + kernfs.InodeNotSymlink + kernfs.InodeNoDynamicLookup + kernfs.InodeDirectoryNoNewChildren + + kernfs.OrderedChildren + dentry kernfs.Dentry +} + +func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { + dir := &readonlyDir{} + dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + dir.dentry.Init(dir) + + dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) + + return &dir.dentry +} + +func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +type dir struct { + attrs + kernfs.InodeNotSymlink + kernfs.InodeNoDynamicLookup + + fs *filesystem + dentry kernfs.Dentry + kernfs.OrderedChildren +} + +func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { + dir := &dir{} + dir.fs = fs + dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) + dir.dentry.Init(dir) + + dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) + + return &dir.dentry +} + +func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { + creds := auth.CredentialsFromContext(ctx) + dir := d.fs.newDir(creds, opts.Mode, nil) + dirVFSD := dir.VFSDentry() + if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { + dir.DecRef() + return nil, err + } + d.IncLinks(1) + return dirVFSD, nil +} + +func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { + creds := auth.CredentialsFromContext(ctx) + f := d.fs.newFile(creds, "") + fVFSD := f.VFSDentry() + if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { + f.DecRef() + return nil, err + } + return fVFSD, nil +} + +func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + fs := &filesystem{} + fs.Init(vfsObj) + root := fst.rootFn(creds, fs) + return fs.VFSFilesystem(), root.VFSDentry(), nil +} + +// -------------------- Remainder of the file are test cases -------------------- + +func TestBasic(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "file1": fs.newFile(creds, staticFileContent), + }) + }) + sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef() +} + +func TestMkdirGetDentry(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir1": fs.newDir(creds, 0755, nil), + }) + }) + + pop := sys.PathOpAtRoot("dir1/a new directory") + if err := sys.vfs.MkdirAt(sys.ctx, sys.creds, &pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { + t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) + } + sys.GetDentryOrDie(pop).DecRef() +} + +func TestReadStaticFile(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "file1": fs.newFile(creds, staticFileContent), + }) + }) + + pop := sys.PathOpAtRoot("file1") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + content, err := sys.ReadToEnd(fd) + if err != nil { + sys.t.Fatalf("Read failed: %v", err) + } + if diff := cmp.Diff(staticFileContent, content); diff != "" { + sys.t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) + } +} + +func TestCreateNewFileInStaticDir(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir1": fs.newDir(creds, 0755, nil), + }) + }) + + pop := sys.PathOpAtRoot("dir1/newfile") + opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode} + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, opts) + if err != nil { + sys.t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err) + } + + // Close the file. The file should persist. + fd.DecRef() + + fd, err = sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) + } + fd.DecRef() +} + +// direntCollector provides an implementation for vfs.IterDirentsCallback for +// testing. It simply iterates to the end of a given directory FD and collects +// all dirents emitted by the callback. +type direntCollector struct { + mu sync.Mutex + dirents map[string]vfs.Dirent +} + +// Handle implements vfs.IterDirentsCallback.Handle. +func (d *direntCollector) Handle(dirent vfs.Dirent) bool { + d.mu.Lock() + if d.dirents == nil { + d.dirents = make(map[string]vfs.Dirent) + } + d.dirents[dirent.Name] = dirent + d.mu.Unlock() + return true +} + +// count returns the number of dirents currently in the collector. +func (d *direntCollector) count() int { + d.mu.Lock() + defer d.mu.Unlock() + return len(d.dirents) +} + +// contains checks whether the collector has a dirent with the given name and +// type. +func (d *direntCollector) contains(name string, typ uint8) error { + d.mu.Lock() + defer d.mu.Unlock() + dirent, ok := d.dirents[name] + if !ok { + return fmt.Errorf("No dirent named %q found", name) + } + if dirent.Type != typ { + return fmt.Errorf("Dirent named %q found, but was expecting type %d, got: %+v", name, typ, dirent) + } + return nil +} + +func TestDirFDReadWrite(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, nil) + }) + + pop := sys.PathOpAtRoot("/") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + // Read/Write should fail for directory FDs. + if _, err := fd.Read(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR { + sys.t.Fatalf("Read for directory FD failed with unexpected error: %v", err) + } + if _, err := fd.Write(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EISDIR { + sys.t.Fatalf("Wrire for directory FD failed with unexpected error: %v", err) + } +} + +func TestDirFDIterDirents(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + // Fill root with nodes backed by various inode implementations. + "dir1": fs.newReadonlyDir(creds, 0755, nil), + "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir3": fs.newDir(creds, 0755, nil), + }), + "file1": fs.newFile(creds, staticFileContent), + }) + }) + + pop := sys.PathOpAtRoot("/") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + collector := &direntCollector{} + if err := fd.IterDirents(sys.ctx, collector); err != nil { + sys.t.Fatalf("IterDirent failed: %v", err) + } + + // Root directory should contain ".", ".." and 3 children: + if collector.count() != 5 { + sys.t.Fatalf("IterDirent returned too many dirents") + } + for _, dirName := range []string{".", "..", "dir1", "dir2"} { + if err := collector.contains(dirName, linux.DT_DIR); err != nil { + sys.t.Fatalf("IterDirent had unexpected results: %v", err) + } + } + if err := collector.contains("file1", linux.DT_REG); err != nil { + sys.t.Fatalf("IterDirent had unexpected results: %v", err) + } + +} diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index d2450e810..0cc751eb8 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -23,14 +24,19 @@ go_library( "directory.go", "filesystem.go", "memfs.go", + "named_pipe.go", "regular_file.go", "symlink.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs", deps = [ "//pkg/abi/linux", + "//pkg/amutex", + "//pkg/fspath", + "//pkg/sentry/arch", "//pkg/sentry/context", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/pipe", "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", @@ -44,6 +50,7 @@ go_test( deps = [ ":memfs", "//pkg/abi/linux", + "//pkg/refs", "//pkg/sentry/context", "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", @@ -53,3 +60,19 @@ go_test( "//pkg/syserror", ], ) + +go_test( + name = "memfs_test", + size = "small", + srcs = ["pipe_test.go"], + embed = [":memfs"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/memfs/benchmark_test.go index a94b17db6..4a7a94a52 100644 --- a/pkg/sentry/fsimpl/memfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/memfs/benchmark_test.go @@ -21,6 +21,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -160,6 +161,8 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) { b.Fatalf("stat(%q) failed: %v", filePath, err) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -173,10 +176,11 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } + defer mntns.DecRef(vfsObj) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') @@ -186,7 +190,6 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { defer root.DecRef() vd := root vd.IncRef() - defer vd.DecRef() for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ @@ -219,6 +222,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) + vd.DecRef() + vd = vfs.VirtualDentry{} if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } @@ -243,6 +248,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) { b.Fatalf("got wrong permissions (%0o)", stat.Mode) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -343,6 +350,8 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { b.Fatalf("stat(%q) failed: %v", filePath, err) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -356,10 +365,11 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } + defer mntns.DecRef(vfsObj) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') @@ -384,7 +394,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { } defer mountPoint.DecRef() // Create and mount the submount. - if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.NewFilesystemOptions{}); err != nil { + if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) @@ -395,7 +405,6 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount root: %v", err) } - defer vd.DecRef() for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ @@ -435,6 +444,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) + vd.DecRef() if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } @@ -459,6 +469,14 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { b.Fatalf("got wrong permissions (%0o)", stat.Mode) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } + +func init() { + // Turn off reference leak checking for a fair comparison between vfs1 and + // vfs2. + refs.SetLeakMode(refs.NoLeakChecking) +} diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c52dc781c..0bd82e480 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -32,7 +32,7 @@ type directory struct { childList dentryList } -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode { dir := &directory{} dir.inode.init(dir, fs, creds, mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root @@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 0 { if !cb.Handle(vfs.Dirent{ - Name: ".", - Type: linux.DT_DIR, - Ino: vfsd.Impl().(*dentry).inode.ino, - Off: 0, + Name: ".", + Type: linux.DT_DIR, + Ino: vfsd.Impl().(*dentry).inode.ino, + NextOff: 1, }) { return nil } @@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 1 { parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode if !cb.Handle(vfs.Dirent{ - Name: "..", - Type: parentInode.direntType(), - Ino: parentInode.ino, - Off: 1, + Name: "..", + Type: parentInode.direntType(), + Ino: parentInode.ino, + NextOff: 2, }) { return nil } @@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba // Skip other directoryFD iterators. if child.inode != nil { if !cb.Handle(vfs.Dirent{ - Name: child.vfsd.Name(), - Type: child.inode.direntType(), - Ino: child.inode.ino, - Off: fd.off, + Name: child.vfsd.Name(), + Type: child.inode.direntType(), + Ino: child.inode.ino, + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go index f79e2d9c8..22f1e811f 100644 --- a/pkg/sentry/fsimpl/memfs/filesystem.go +++ b/pkg/sentry/fsimpl/memfs/filesystem.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -159,7 +160,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op return nil, err } } - inode.incRef() // vfsd.IncRef(&fs.vfsfs) + inode.incRef() return vfsd, nil } @@ -233,7 +234,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if err != nil { return err } - _, err = checkCreateLocked(rp, parentVFSD, parentInode) + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) if err != nil { return err } @@ -241,8 +242,40 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - // TODO: actually implement mknod - return syserror.EPERM + + switch opts.Mode.FileType() { + case 0: + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + fallthrough + case linux.ModeRegular: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeNamedPipe: + child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode)) + parentVFSD.InsertChild(&child.vfsd, pc) + parentInode.impl.(*directory).childList.PushBack(child) + return nil + + case linux.ModeSocket: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeCharacterDevice: + fallthrough + case linux.ModeBlockDevice: + // TODO(b/72101894): We don't support creating block or character + // devices at the moment. + // + // When we start supporting block and character devices, we'll + // need to check for CAP_MKNOD here. + return syserror.EPERM + + default: + // "EINVAL - mode requested creation of something other than a + // regular file, device special file, FIFO or socket." - mknod(2) + return syserror.EINVAL + } } // OpenAt implements vfs.FilesystemImpl.OpenAt. @@ -250,8 +283,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // Filter out flags that are not supported by memfs. O_DIRECTORY and // O_NOFOLLOW have no effect here (they're handled by VFS by setting // appropriate bits in rp), but are returned by - // FileDescriptionImpl.StatusFlags(). - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by + // pipes. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() @@ -260,7 +294,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return nil, err } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } mustCreate := opts.Flags&linux.O_EXCL != 0 @@ -275,7 +309,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } afterTrailingSymlink: // Walk to the parent directory of the last path component. @@ -320,7 +354,7 @@ afterTrailingSymlink: child := fs.newDentry(childInode) vfsd.InsertChild(&child.vfsd, pc) inode.impl.(*directory).childList.PushBack(child) - return childInode.open(rp, &child.vfsd, opts.Flags, true) + return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true) } // Open existing file or follow symlink. if mustCreate { @@ -336,16 +370,17 @@ afterTrailingSymlink: // symlink target. goto afterTrailingSymlink } - return childInode.open(rp, childVFSD, opts.Flags, false) + return childInode.open(ctx, rp, childVFSD, opts.Flags, false) } -func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { +func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(flags) if !afterCreate { if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil { return nil, err } } + mnt := rp.Mount() switch impl := i.impl.(type) { case *regularFile: var fd regularFileFD @@ -353,12 +388,14 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte fd.readable = vfs.MayReadFileWithOpenFlags(flags) fd.writable = vfs.MayWriteFileWithOpenFlags(flags) if fd.writable { - if err := rp.Mount().CheckBeginWrite(); err != nil { + if err := mnt.CheckBeginWrite(); err != nil { return nil, err } - // Mount.EndWrite() is called by regularFileFD.Release(). + // mnt.EndWrite() is called by regularFileFD.Release(). } - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, mnt, vfsd) if flags&linux.O_TRUNC != 0 { impl.mu.Lock() impl.data = impl.data[:0] @@ -372,12 +409,16 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte return nil, syserror.EISDIR } var fd directoryFD - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, mnt, vfsd) fd.flags = flags return &fd.vfsfd, nil case *symlink: // Can't open symlinks without O_PATH (which is unimplemented). return nil, syserror.ELOOP + case *namedPipe: + return newNamedPipeFD(ctx, impl, rp, vfsd, flags) default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } @@ -542,3 +583,58 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error inode.decLinksLocked() return nil } + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return nil, err + } + // TODO(b/127675828): support extended attributes + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return "", err + } + // TODO(b/127675828): support extended attributes + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index 45cd42b3e..4cb2a4e0f 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -52,10 +52,10 @@ type filesystem struct { nextInoMinusOne uint64 // accessed using atomic memory operations } -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { var fs filesystem - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) root := fs.newDentry(fs.newDirectory(creds, 01777)) return &fs.vfsfs, &root.vfsd, nil } @@ -99,17 +99,17 @@ func (fs *filesystem) newDentry(inode *inode) *dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef() { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef() bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { +func (d *dentry) DecRef() { d.inode.decRef() } @@ -137,7 +137,7 @@ type inode struct { impl interface{} // immutable } -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) { +func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) @@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS stat.Size = uint64(len(impl.target)) stat.Blocks = allocatedBlocksForSize(stat.Size) + case *namedPipe: + stat.Mode |= linux.S_IFIFO default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } @@ -264,11 +266,11 @@ type fileDescription struct { } func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) + return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } func (fd *fileDescription) inode() *inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode + return fd.vfsfd.Dentry().Impl().(*dentry).inode } // StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go new file mode 100644 index 000000000..91cb4b1fc --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/named_pipe.go @@ -0,0 +1,62 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +type namedPipe struct { + inode inode + + pipe *pipe.VFSPipe +} + +// Preconditions: +// * fs.mu must be locked. +// * rp.Mount().CheckBeginWrite() has been called successfully. +func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode { + file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)} + file.inode.init(file, fs, creds, mode) + file.inode.nlink = 1 // Only the parent has a link. + return &file.inode +} + +// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented +// entirely via struct embedding. +type namedPipeFD struct { + fileDescription + + *pipe.VFSPipeFD +} + +func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + var err error + var fd namedPipeFD + fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags) + if err != nil { + return nil, err + } + mnt := rp.Mount() + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, mnt, vfsd) + return &fd.vfsfd, nil +} diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go new file mode 100644 index 000000000..5bf527c80 --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/pipe_test.go @@ -0,0 +1,233 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memfs + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const fileName = "mypipe" + +func TestSeparateFDs(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the read side. This is done in a concurrently because opening + // One end the pipe blocks until the other end is opened. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + rfdchan := make(chan *vfs.FileDescription) + go func() { + openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY} + rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + rfdchan <- rfd + }() + + // Open the write side. + openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY} + wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer wfd.DecRef() + + rfd, ok := <-rfdchan + if !ok { + t.Fatalf("failed to open pipe for reading %q", fileName) + } + defer rfd.DecRef() + + const msg = "vamos azul" + checkEmpty(ctx, t, rfd) + checkWrite(ctx, t, wfd, msg) + checkRead(ctx, t, rfd, msg) +} + +func TestNonblockingRead(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the read side as nonblocking. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK} + rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) + } + defer rfd.DecRef() + + // Open the write side. + openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} + wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer wfd.DecRef() + + const msg = "geh blau" + checkEmpty(ctx, t, rfd) + checkWrite(ctx, t, wfd, msg) + checkRead(ctx, t, rfd, msg) +} + +func TestNonblockingWriteError(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the write side as nonblocking, which should return ENXIO. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} + _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != syserror.ENXIO { + t.Fatalf("expected ENXIO, but got error: %v", err) + } +} + +func TestSingleFD(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the pipe as readable and writable. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_RDWR} + fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer fd.DecRef() + + const msg = "forza blu" + checkEmpty(ctx, t, fd) + checkWrite(ctx, t, fd, msg) + checkRead(ctx, t, fd, msg) +} + +// setup creates a VFS with a pipe in the root directory at path fileName. The +// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal +// upon failure. +func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + + // Create VFS. + vfsObj := vfs.New() + vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("failed to create tmpfs root mount: %v", err) + } + + // Create the pipe. + root := mntns.Root() + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644} + if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil { + t.Fatalf("failed to create file %q: %v", fileName, err) + } + + // Sanity check: the file pipe exists and has the correct mode. + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + }, &vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat(%q) failed: %v", fileName, err) + } + if stat.Mode&^linux.S_IFMT != 0644 { + t.Errorf("got wrong permissions (%0o)", stat.Mode) + } + if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe { + t.Errorf("got wrong file type (%0o)", stat.Mode) + } + + return ctx, creds, vfsObj, root +} + +// checkEmpty calls t.Fatal if the pipe in fd is not empty. +func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + readData := make([]byte, 1) + dst := usermem.BytesIOSequence(readData) + bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) + if err != syserror.ErrWouldBlock { + t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) + } + if bytesRead != 0 { + t.Fatalf("expected to read 0 bytes, but got %d", bytesRead) + } +} + +// checkWrite calls t.Fatal if it fails to write all of msg to fd. +func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { + writeData := []byte(msg) + src := usermem.BytesIOSequence(writeData) + bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{}) + if err != nil { + t.Fatalf("error writing to pipe %q: %v", fileName, err) + } + if bytesWritten != int64(len(writeData)) { + t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten) + } +} + +// checkRead calls t.Fatal if it fails to read msg from fd. +func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { + readData := make([]byte, len(msg)) + dst := usermem.BytesIOSequence(readData) + bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{}) + if err != nil { + t.Fatalf("error reading from pipe %q: %v", fileName, err) + } + if bytesRead != int64(len(msg)) { + t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead) + } + if !bytes.Equal(readData, []byte(msg)) { + t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData)) + } +} diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go index 55f869798..b7f4853b3 100644 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ b/pkg/sentry/fsimpl/memfs/regular_file.go @@ -37,7 +37,7 @@ type regularFile struct { dataLen int64 } -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { file := ®ularFile{} file.inode.init(file, fs, creds, mode) file.inode.nlink = 1 // from parent directory diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 3d8a4deaf..ade6ac946 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go index c36c4aff5..0e016bca5 100644 --- a/pkg/sentry/fsimpl/proc/filesystems.go +++ b/pkg/sentry/fsimpl/proc/filesystems.go @@ -19,7 +19,7 @@ package proc // +stateify savable type filesystemsData struct{} -// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for +// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for // filesystemsData. We would need to retrive filesystem names from // vfs.VirtualFilesystem. Also needs vfs replacement for // fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev. diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go index e81b1e910..8683cf677 100644 --- a/pkg/sentry/fsimpl/proc/mounts.go +++ b/pkg/sentry/fsimpl/proc/mounts.go @@ -16,7 +16,7 @@ package proc import "gvisor.dev/gvisor/pkg/sentry/kernel" -// TODO(b/138862512): Implement mountInfoFile and mountsFile. +// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile. // mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo. // diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index c46e05c3a..0d87be52b 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -229,6 +229,10 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(buf, "Mems_allowed:\t1\n") + fmt.Fprintf(buf, "Mems_allowed_list:\t0\n") return nil } diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index f989f2f8b..359468ccc 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,6 +7,7 @@ go_library( name = "hostcpu", srcs = [ "getcpu_amd64.s", + "getcpu_arm64.s", "hostcpu.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s new file mode 100644 index 000000000..caf9abb89 --- /dev/null +++ b/pkg/sentry/hostcpu/getcpu_arm64.s @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for +// the lack of an optimazed way of getting the current CPU number on arm64. + +// func GetCPU() (cpu uint32) +TEXT ·GetCPU(SB), NOSPLIT, $0-4 + MOVW ZR, cpu+0(FP) + MOVD $cpu+0(FP), R0 + MOVD $0x0, R1 // unused + MOVD $0x0, R2 // unused + MOVD $0xA8, R8 // SYS_GETCPU + SVC + RET diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 184b566d9..8d60ad4ad 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,10 +1,10 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "inet", srcs = [ @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip/stack", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80f227dbe..a7dfb78a7 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,16 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // RegisteredEndpoints returns all endpoints which are currently registered. + RegisteredEndpoints() []stack.TransportEndpoint + + // CleanupEndpoints returns endpoints currently in the cleanup state. + CleanupEndpoints() []stack.TransportEndpoint + + // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful + // for restoring a stack after a save. + RestoreCleanupEndpoints([]stack.TransportEndpoint) } // Interface contains information about a network interface. @@ -153,3 +165,23 @@ type Route struct { // GatewayAddr is the route gateway address (RTA_GATEWAY). GatewayAddr []byte } + +// Below SNMP metrics are from Linux/usr/include/linux/snmp.h. + +// StatSNMPIP describes Ip line of /proc/net/snmp. +type StatSNMPIP [19]uint64 + +// StatSNMPICMP describes Icmp line of /proc/net/snmp. +type StatSNMPICMP [27]uint64 + +// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp. +type StatSNMPICMPMSG [512]uint64 + +// StatSNMPTCP describes Tcp line of /proc/net/snmp. +type StatSNMPTCP [15]uint64 + +// StatSNMPUDP describes Udp line of /proc/net/snmp. +type StatSNMPUDP [8]uint64 + +// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. +type StatSNMPUDPLite [8]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..dcfcbd97e 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,8 @@ package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route { } // Resume implements Stack.Resume. -func (s *TestStack) Resume() { +func (s *TestStack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { + return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { + return nil +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e61d39c82..2706927ff 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,10 +1,11 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") - go_template_instance( name = "pending_signals_list", out = "pending_signals_list.go", @@ -34,7 +35,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//third_party/gvsync:generic_seqatomic", + template = "//pkg/syncutil:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -83,6 +84,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "uncaught_signal_cc_proto", + visibility = ["//visibility:public"], + deps = [":uncaught_signal_proto"], +) + go_proto_library( name = "uncaught_signal_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", @@ -144,6 +151,7 @@ go_library( "threads.go", "timekeeper.go", "timekeeper_state.go", + "tty.go", "uts_namespace.go", "vdso.go", "version.go", @@ -201,12 +209,12 @@ go_library( "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", + "//pkg/syncutil", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/stack", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 1d00a6310..04c244447 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -1,14 +1,14 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "atomicptr_credentials", out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//third_party/gvsync:generic_atomicptr", + template = "//pkg/syncutil:generic_atomicptr", types = { "Value": "Credentials", }, diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index e3f5b0d83..3c9dceaba 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,6 +15,8 @@ package kernel import ( + "time" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task { return nil } +// Deadline implements context.Context.Deadline. +func (*Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done implements context.Context.Done. +func (*Task) Done() <-chan struct{} { + return nil +} + +// Err implements context.Context.Err. +func (*Task) Err() error { + return nil +} + // AsyncContext returns a context.Context that may be used by goroutines that // do work on behalf of t and therefore share its contextual values, but are // not t's task goroutine (e.g. asynchronous I/O). @@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool { return ctx.t.IsLogging(level) } +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return ctx.t.Deadline() +} + +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return ctx.t.Done() +} + +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return ctx.t.Err() +} + // Value implements context.Context.Value. func (ctx taskAsyncContext) Value(key interface{}) interface{} { return ctx.t.Value(key) diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD index bec13a3d9..3a88a585c 100644 --- a/pkg/sentry/kernel/contexttest/BUILD +++ b/pkg/sentry/kernel/contexttest/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "contexttest", testonly = 1, diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index f46c43128..3361e8b7d 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "epoll_list", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 1c5f979d4..e65b961e8 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "eventfd", diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 5eddca115..49d81b712 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "fasync", srcs = ["fasync.go"], diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index cc3f43a45..11f613a11 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -81,6 +81,9 @@ type FDTable struct { // mu protects below. mu sync.Mutex `state:"nosave"` + // next is start position to find fd. + next int32 + // used contains the number of non-nil entries. It must be accessed // atomically. It may be read atomically without holding mu (but not // written). @@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.mu.Lock() defer f.mu.Unlock() + // From f.next to find available fd. + if fd < f.next { + fd = f.next + } + // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { @@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return nil, syscall.EMFILE } + if fd == f.next { + // Update next search start position. + f.next = fds[len(fds)-1] + 1 + } + return fds, nil } @@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File { f.mu.Lock() defer f.mu.Unlock() + // Update current available position. + if fd < f.next { + f.next = fd + } + orig, _, _ := f.get(fd) if orig != nil { orig.IncRef() // Reference for caller. @@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) { f.forEach(func(fd int32, file *fs.File, flags FDFlags) { if cond(file, flags) { f.set(fd, nil, FDFlags{}) // Clear from table. + // Update current available position. + if fd < f.next { + f.next = fd + } } }) } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2413788e7..2bcb6216a 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) { if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) } + + i := int32(2) + fdTable.Remove(i) + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { + t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) + } + }) +} + +func TestFDTableOverLimit(t *testing.T) { + runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { + if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error") + } + + if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error") + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil { + t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) + } else { + for _, fd := range fds { + fdTable.Remove(fd) + } + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 { + t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) + } + + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { + t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) + } else if len(fds) != 1 || fds[0] != 0 { + t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) + } }) } diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 6a31dc044..75ec31761 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,14 +1,15 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "atomicptr_bucket", out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//third_party/gvsync:generic_atomicptr", + template = "//pkg/syncutil:generic_atomicptr", types = { "Value": "bucket", }, diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8c1f79ab5..bd3fb4c03 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -24,6 +24,7 @@ // TaskSet.mu // SignalHandlers.mu // Task.mu +// runningTasksMu // // Locking SignalHandlers.mu in multiple SignalHandlers requires locking // TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same @@ -135,6 +136,22 @@ type Kernel struct { // syslog is the kernel log. syslog syslog + // runningTasksMu synchronizes disable/enable of cpuClockTicker when + // the kernel is idle (runningTasks == 0). + // + // runningTasksMu is used to exclude critical sections when the timer + // disables itself and when the first active task enables the timer, + // ensuring that tasks always see a valid cpuClock value. + runningTasksMu sync.Mutex `state:"nosave"` + + // runningTasks is the total count of tasks currently in + // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are + // not blocked or stopped. + // + // runningTasks must be accessed atomically. Increments from 0 to 1 are + // further protected by runningTasksMu (see incRunningTasks). + runningTasks int64 + // cpuClock is incremented every linux.ClockTick. cpuClock is used to // measure task CPU usage, since sampling monotonicClock twice on every // syscall turns out to be unreasonably expensive. This is similar to how @@ -150,6 +167,22 @@ type Kernel struct { // cpuClockTicker increments cpuClock. cpuClockTicker *ktime.Timer `state:"nosave"` + // cpuClockTickerDisabled indicates that cpuClockTicker has been + // disabled because no tasks are running. + // + // cpuClockTickerDisabled is protected by runningTasksMu. + cpuClockTickerDisabled bool + + // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the + // point it was disabled. It is cached here to avoid a lock ordering + // violation with cpuClockTicker.mu when runningTaskMu is held. + // + // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is + // true. + // + // cpuClockTickerSetting is protected by runningTasksMu. + cpuClockTickerSetting ktime.Setting + // fdMapUids is an ever-increasing counter for generating FDTable uids. // // fdMapUids is mutable, and is accessed using atomic memory operations. @@ -358,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(w, k.FeatureSet(), nil); err != nil { + if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) @@ -366,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Save(w, k, &stats); err != nil { + if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { return err } log.Infof("Kernel save stats: %s", &stats) @@ -374,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(w); err != nil { + if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -509,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(r, &features, nil); err != nil { + if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -525,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Load(r, k, &stats); err != nil { + if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { return err } log.Infof("Kernel load stats: %s", &stats) @@ -533,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(r); err != nil { + if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -771,8 +804,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: mounts, + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: true, + Filename: args.Filename, + File: args.File, + CloseOnExec: false, + Argv: args.Argv, + Envv: args.Envv, + Features: k.featureSet, + } - tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) + tc, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } @@ -795,9 +841,11 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, AbstractSocketNamespace: args.AbstractSocketNamespace, ContainerID: args.ContainerID, } - if _, err := k.tasks.NewTask(config); err != nil { + t, err := k.tasks.NewTask(config) + if err != nil { return nil, 0, err } + t.traceExecEvent(tc) // Simulate exec for tracing. // Success. tgid := k.tasks.Root.IDOfThreadGroup(tg) @@ -912,6 +960,102 @@ func (k *Kernel) resumeTimeLocked() { } } +func (k *Kernel) incRunningTasks() { + for { + tasks := atomic.LoadInt64(&k.runningTasks) + if tasks != 0 { + // Standard case. Simply increment. + if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) { + continue + } + return + } + + // Transition from 0 -> 1. Synchronize with other transitions and timer. + k.runningTasksMu.Lock() + tasks = atomic.LoadInt64(&k.runningTasks) + if tasks != 0 { + // We're no longer the first task, no need to + // re-enable. + atomic.AddInt64(&k.runningTasks, 1) + k.runningTasksMu.Unlock() + return + } + + if !k.cpuClockTickerDisabled { + // Timer was never disabled. + atomic.StoreInt64(&k.runningTasks, 1) + k.runningTasksMu.Unlock() + return + } + + // We need to update cpuClock for all of the ticks missed while we + // slept, and then re-enable the timer. + // + // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify + // always increments cpuClock by 1 regardless of the number of + // expirations as a heuristic to avoid over-accounting in cases of CPU + // throttling. + // + // We want to cover the normal case, when all time should be accounted, + // so we increment for all expirations. Throttling is less concerning + // here because the ticker is only disabled from Notify. This means + // that Notify must schedule and compensate for the throttled period + // before the timer is disabled. Throttling while the timer is disabled + // doesn't matter, as nothing is running or reading cpuClock anyways. + // + // S/R also adds complication, as there are two cases. Recall that + // monotonicClock will jump forward on restore. + // + // 1. If the ticker is enabled during save, then on Restore Notify is + // called with many expirations, covering the time jump, but cpuClock + // is only incremented by 1. + // + // 2. If the ticker is disabled during save, then after Restore the + // first wakeup will call this function and cpuClock will be + // incremented by the number of expirations across the S/R. + // + // These cause very different value of cpuClock. But again, since + // nothing was running while the ticker was disabled, those differences + // don't matter. + setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now()) + if exp > 0 { + atomic.AddUint64(&k.cpuClock, exp) + } + + // Now that cpuClock is updated it is safe to allow other tasks to + // transition to running. + atomic.StoreInt64(&k.runningTasks, 1) + + // N.B. we must unlock before calling Swap to maintain lock ordering. + // + // cpuClockTickerDisabled need not wait until after Swap to become + // true. It is sufficient that the timer *will* be enabled. + k.cpuClockTickerDisabled = false + k.runningTasksMu.Unlock() + + // This won't call Notify (unless it's been ClockTick since setting.At + // above). This means we skip the thread group work in Notify. However, + // since nothing was running while we were disabled, none of the timers + // could have expired. + k.cpuClockTicker.Swap(setting) + + return + } +} + +func (k *Kernel) decRunningTasks() { + tasks := atomic.AddInt64(&k.runningTasks, -1) + if tasks < 0 { + panic(fmt.Sprintf("Invalid running count %d", tasks)) + } + + // Nothing to do. The next CPU clock tick will disable the timer if + // there is still nothing running. This provides approximately one tick + // of slack in which we can switch back and forth between idle and + // active without an expensive transition. +} + // WaitExited blocks until all tasks in k have exited. func (k *Kernel) WaitExited() { k.tasks.liveGoroutines.Wait() @@ -976,6 +1120,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { return lastErr } +// RebuildTraceContexts rebuilds the trace context for all tasks. +// +// Unfortunately, if these are built while tracing is not enabled, then we will +// not have meaningful trace data. Rebuilding here ensures that we can do so +// after tracing has been enabled. +func (k *Kernel) RebuildTraceContexts() { + k.extMu.Lock() + defer k.extMu.Unlock() + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + + for t, tid := range k.tasks.Root.tids { + t.rebuildTraceContext(tid) + } +} + // FeatureSet returns the FeatureSet. func (k *Kernel) FeatureSet() *cpuid.FeatureSet { return k.featureSet @@ -1180,6 +1340,7 @@ func (k *Kernel) ListSockets() []*SocketEntry { return socks } +// supervisorContext is a privileged context. type supervisorContext struct { context.NoopSleeper log.Logger diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index ebcfaa619..d7a7d1169 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -24,6 +25,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "memory_events_cc_proto", + visibility = ["//visibility:public"], + deps = [":memory_events_proto"], +) + go_proto_library( name = "memory_events_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4d15cca85..9d34f6d4d 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "buffer_list", @@ -23,8 +24,10 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_util.go", "reader.go", "reader_writer.go", + "vfs.go", "writer.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", @@ -39,6 +42,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 69ef2a720..95bee2d37 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" @@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } +// WriteFromReader writes to the buffer from an io.Reader. +func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { + dst := b.data[b.write:] + if count < int64(len(dst)) { + dst = b.data[b.write:][:count] + } + n, err := r.Read(dst) + b.write += n + return int64(n), err +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) @@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return n, err } +// ReadToWriter reads from the buffer into an io.Writer. +func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { + src := b.data[b.read:b.write] + if count < int64(len(src)) { + src = b.data[b.read:][:count] + } + n, err := w.Write(src) + if !dup { + b.read += n + } + return int64(n), err +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index a2dc72204..4a19ab7ce 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi switch { case flags.Read && !flags.Write: // O_RDONLY. r := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) + newHandleLocked(&i.rWakeup) if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { - if !i.waitFor(&i.wWakeup, ctx) { + if !waitFor(&i.mu, &i.wWakeup, ctx) { r.DecRef() return nil, syserror.ErrInterrupted } @@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Write && !flags.Read: // O_WRONLY. w := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.wWakeup) if i.p.isNamed && !i.p.HasReaders() { // On a nonblocking, write-only open, the open fails with ENXIO if the @@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi return nil, syserror.ENXIO } - if !i.waitFor(&i.rWakeup, ctx) { + if !waitFor(&i.mu, &i.rWakeup, ctx) { w.DecRef() return nil, syserror.ErrInterrupted } @@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Read && flags.Write: // O_RDWR. // Pipes opened for read-write always succeeds without blocking. rw := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.rWakeup) + newHandleLocked(&i.wWakeup) return rw, nil default: @@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi } } -// waitFor blocks until the underlying pipe has at least one reader/writer is -// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this -// function will block for either readers or writers, depending on where -// 'wakeupChan' points. -// -// f.mu must be held by the caller. waitFor returns with f.mu held, but it will -// drop f.mu before blocking for any reader/writers. -func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { - // Ideally this function would simply use a condition variable. However, the - // wait needs to be interruptible via 'sleeper', so we must sychronize via a - // channel. The synchronization below relies on the fact that closing a - // channel unblocks all receives on the channel. - - // Does an appropriate wakeup channel already exist? If not, create a new - // one. This is all done under f.mu to avoid races. - if *wakeupChan == nil { - *wakeupChan = make(chan struct{}) - } - - // Grab a local reference to the wakeup channel since it may disappear as - // soon as we drop f.mu. - wakeup := *wakeupChan - - // Drop the lock and prepare to sleep. - i.mu.Unlock() - cancel := sleeper.SleepStart() - - // Wait for either a new reader/write to be signalled via 'wakeup', or - // for the sleep to be cancelled. - select { - case <-wakeup: - sleeper.SleepFinish(true) - case <-cancel: - sleeper.SleepFinish(false) - } - - // Take the lock and check if we were woken. If we were woken and - // interrupted, the former takes priority. - i.mu.Lock() - select { - case <-wakeup: - return true - default: - return false - } -} - -// newHandleLocked signals a new pipe reader or writer depending on where -// 'wakeupChan' points. This unblocks any corresponding reader or writer -// waiting for the other end of the channel to be opened, see Fifo.waitFor. -// -// i.mu must be held. -func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) { - if *wakeupChan != nil { - close(*wakeupChan) - *wakeupChan = nil - } -} - func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error { return syserror.EPIPE } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index adbad7764..16fa80abe 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize, usermem.PageSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize, usermem.PageSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 247e2928e..1a1b38f83 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -99,7 +98,7 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // // N.B. The size and atomicIOBytes will be bounded. -func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } @@ -112,17 +111,33 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) if atomicIOBytes > sizeBytes { atomicIOBytes = sizeBytes } - return &Pipe{ - isNamed: isNamed, - max: sizeBytes, - atomicIOBytes: atomicIOBytes, + var p Pipe + initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + return &p +} + +func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { + if sizeBytes < MinimumPipeSize { + sizeBytes = MinimumPipeSize + } + if sizeBytes > MaximumPipeSize { + sizeBytes = MaximumPipeSize + } + if atomicIOBytes <= 0 { + atomicIOBytes = 1 } + if atomicIOBytes > sizeBytes { + atomicIOBytes = sizeBytes + } + pipe.isNamed = isNamed + pipe.max = sizeBytes + pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes) + p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. @@ -173,13 +188,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } +type readOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit limits subsequence reads. + limit func(int64) + + // read performs the actual read operation. + read func(*buffer) (int64, error) +} + // read reads data from the pipe into dst and returns the number of bytes // read, or returns ErrWouldBlock if the pipe is empty. // // Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) { +func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if dst.NumBytes() == 0 { + if ops.left() == 0 { return 0, nil } @@ -196,12 +222,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Limit how much we consume. - if dst.NumBytes() > p.size { - dst = dst.TakeFirst64(p.size) + if ops.left() > p.size { + ops.limit(p.size) } done := int64(0) - for dst.NumBytes() > 0 { + for ops.left() > 0 { // Pop the first buffer. first := p.data.Front() if first == nil { @@ -209,10 +235,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Copy user data. - n, err := dst.CopyOutFrom(ctx, first) + n, err := ops.read(first) done += int64(n) p.size -= n - dst = dst.DropFirst64(n) // Empty buffer? if first.Empty() { @@ -230,12 +255,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) return done, nil } +// dup duplicates all data from this pipe into the given writer. +// +// There is no blocking behavior implemented here. The writer may propagate +// some blocking error. All the writes must be complete writes. +func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Is the pipe empty? + if p.size == 0 { + if !p.HasWriters() { + // See above. + return 0, nil + } + return 0, syserror.ErrWouldBlock + } + + // Limit how much we consume. + if ops.left() > p.size { + ops.limit(p.size) + } + + done := int64(0) + for buf := p.data.Front(); buf != nil; buf = buf.Next() { + n, err := ops.read(buf) + done += n + if err != nil { + return done, err + } + } + + return done, nil +} + +type writeOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit should limit subsequent writes. + limit func(int64) + + // write should write to the provided buffer. + write func(*buffer) (int64, error) +} + // write writes data from sv into the pipe and returns the number of bytes // written. If no bytes are written because the pipe is full (or has less than // atomicIOBytes free capacity), write returns ErrWouldBlock. // // Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) { +func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() @@ -246,17 +316,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := src.NumBytes() + wanted := ops.left() if avail := p.max - p.size; wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } - // Limit to the available capacity. - src = src.TakeFirst64(avail) + ops.limit(avail) } done := int64(0) - for src.NumBytes() > 0 { + for ops.left() > 0 { // Need a new buffer? last := p.data.Back() if last == nil || last.Full() { @@ -266,10 +335,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) } // Copy user data. - n, err := src.CopyInTo(ctx, last) + n, err := ops.write(last) done += int64(n) p.size += n - src = src.DropFirst64(n) // Handle errors. if err != nil { diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go new file mode 100644 index 000000000..ef9641e6a --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -0,0 +1,213 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "io" + "math" + "sync" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// This file contains Pipe file functionality that is tied to neither VFS nor +// the old fs architecture. + +// Release cleans up the pipe's state. +func (p *Pipe) Release() { + p.rClose() + p.wClose() + + // Wake up readers and writers. + p.Notify(waiter.EventIn | waiter.EventOut) +} + +// Read reads from the Pipe into dst. +func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) { + n, err := p.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo writes to w from the Pipe. +func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return p.dup(ctx, ops) + } + n, err := p.read(ctx, ops) + if n > 0 { + p.Notify(waiter.EventOut) + } + return n, err +} + +// Write writes to the Pipe from src. +func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) { + n, err := p.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom reads from r to the Pipe. +func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) { + n, err := p.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventIn) + } + return n, err +} + +// Readiness returns the ready events in the underlying pipe. +func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask { + return p.rwReadiness() & mask +} + +// Ioctl implements ioctls on the Pipe. +func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + // Switch on ioctl request. + switch int(args[1].Int()) { + case linux.FIONREAD: + v := p.queued() + if v > math.MaxInt32 { + v = math.MaxInt32 // Silently truncate. + } + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + default: + return 0, syscall.ENOTTY + } +} + +// waitFor blocks until the underlying pipe has at least one reader/writer is +// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this +// function will block for either readers or writers, depending on where +// 'wakeupChan' points. +// +// mu must be held by the caller. waitFor returns with mu held, but it will +// drop mu before blocking for any reader/writers. +func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { + // Ideally this function would simply use a condition variable. However, the + // wait needs to be interruptible via 'sleeper', so we must sychronize via a + // channel. The synchronization below relies on the fact that closing a + // channel unblocks all receives on the channel. + + // Does an appropriate wakeup channel already exist? If not, create a new + // one. This is all done under f.mu to avoid races. + if *wakeupChan == nil { + *wakeupChan = make(chan struct{}) + } + + // Grab a local reference to the wakeup channel since it may disappear as + // soon as we drop f.mu. + wakeup := *wakeupChan + + // Drop the lock and prepare to sleep. + mu.Unlock() + cancel := sleeper.SleepStart() + + // Wait for either a new reader/write to be signalled via 'wakeup', or + // for the sleep to be cancelled. + select { + case <-wakeup: + sleeper.SleepFinish(true) + case <-cancel: + sleeper.SleepFinish(false) + } + + // Take the lock and check if we were woken. If we were woken and + // interrupted, the former takes priority. + mu.Lock() + select { + case <-wakeup: + return true + default: + return false + } +} + +// newHandleLocked signals a new pipe reader or writer depending on where +// 'wakeupChan' points. This unblocks any corresponding reader or writer +// waiting for the other end of the channel to be opened, see Fifo.waitFor. +// +// Precondition: the mutex protecting wakeupChan must be held. +func newHandleLocked(wakeupChan *chan struct{}) { + if *wakeupChan != nil { + close(*wakeupChan) + *wakeupChan = nil + } +} diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index f69dbf27b..b4d29fc77 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -15,16 +15,13 @@ package pipe import ( - "math" - "syscall" + "io" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) // ReaderWriter satisfies the FileOperations interface and services both @@ -44,53 +41,27 @@ type ReaderWriter struct { *Pipe } -// Release implements fs.FileOperations.Release. -func (rw *ReaderWriter) Release() { - rw.Pipe.rClose() - rw.Pipe.wClose() - - // Wake up readers and writers. - rw.Pipe.Notify(waiter.EventIn | waiter.EventOut) -} - // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, dst) - if n > 0 { - rw.Pipe.Notify(waiter.EventOut) - } - return n, err + return rw.Pipe.Read(ctx, dst) +} + +// WriteTo implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { + return rw.Pipe.WriteTo(ctx, w, count, dup) } // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, src) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err + return rw.Pipe.Write(ctx, src) } -// Readiness returns the ready events in the underlying pipe. -func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask { - return rw.Pipe.rwReadiness() & mask +// ReadFrom implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + return rw.Pipe.ReadFrom(ctx, r, count) } // Ioctl implements fs.FileOperations.Ioctl. func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // Switch on ioctl request. - switch int(args[1].Int()) { - case linux.FIONREAD: - v := rw.queued() - if v > math.MaxInt32 { - v = math.MaxInt32 // Silently truncate. - } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err - default: - return 0, syscall.ENOTTY - } + return rw.Pipe.Ioctl(ctx, io, args) } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go new file mode 100644 index 000000000..6416e0dd8 --- /dev/null +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -0,0 +1,220 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// This file contains types enabling the pipe package to be used with the vfs +// package. + +// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should +// not be copied. +type VFSPipe struct { + // mu protects the fields below. + mu sync.Mutex `state:"nosave"` + + // pipe is the underlying pipe. + pipe Pipe + + // Channels for synchronizing the creation of new readers and writers + // of this fifo. See waitFor and newHandleLocked. + // + // These are not saved/restored because all waiters are unblocked on + // save, and either automatically restart (via ERESTARTSYS) or return + // EINTR on resume. On restarts via ERESTARTSYS, the appropriate + // channel will be recreated. + rWakeup chan struct{} `state:"nosave"` + wWakeup chan struct{} `state:"nosave"` +} + +// NewVFSPipe returns an initialized VFSPipe. +func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe { + var vp VFSPipe + initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes) + return &vp +} + +// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics +// during open: +// +// "Normally, opening the FIFO blocks until the other end is opened also. A +// process can open a FIFO in nonblocking mode. In this case, opening for +// read-only will succeed even if no-one has opened on the write side yet, +// opening for write-only will fail with ENXIO (no such device or address) +// unless the other end has already been opened. Under Linux, opening a FIFO +// for read and write will succeed both in blocking and nonblocking mode. POSIX +// leaves this behavior undefined. This can be used to open a FIFO for writing +// while there are no readers available." - fifo(7) +func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { + vp.mu.Lock() + defer vp.mu.Unlock() + + readable := vfs.MayReadFileWithOpenFlags(flags) + writable := vfs.MayWriteFileWithOpenFlags(flags) + if !readable && !writable { + return nil, syserror.EINVAL + } + + vfd, err := vp.open(rp, vfsd, vfsfd, flags) + if err != nil { + return nil, err + } + + switch { + case readable && writable: + // Pipes opened for read-write always succeed without blocking. + newHandleLocked(&vp.rWakeup) + newHandleLocked(&vp.wWakeup) + + case readable: + newHandleLocked(&vp.rWakeup) + // If this pipe is being opened as nonblocking and there's no + // writer, we have to wait for a writer to open the other end. + if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { + return nil, syserror.EINTR + } + + case writable: + newHandleLocked(&vp.wWakeup) + + if !vp.pipe.HasReaders() { + // Nonblocking, write-only opens fail with ENXIO when + // the read side isn't open yet. + if flags&linux.O_NONBLOCK != 0 { + return nil, syserror.ENXIO + } + // Wait for a reader to open the other end. + if !waitFor(&vp.mu, &vp.rWakeup, ctx) { + return nil, syserror.EINTR + } + } + + default: + panic("invalid pipe flags: must be readable, writable, or both") + } + + return vfd, nil +} + +// Preconditions: vp.mu must be held. +func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { + var fd VFSPipeFD + fd.flags = flags + fd.readable = vfs.MayReadFileWithOpenFlags(flags) + fd.writable = vfs.MayWriteFileWithOpenFlags(flags) + fd.vfsfd = vfsfd + fd.pipe = &vp.pipe + if fd.writable { + // The corresponding Mount.EndWrite() is in VFSPipe.Release(). + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + } + + switch { + case fd.readable && fd.writable: + vp.pipe.rOpen() + vp.pipe.wOpen() + case fd.readable: + vp.pipe.rOpen() + case fd.writable: + vp.pipe.wOpen() + default: + panic("invalid pipe flags: must be readable, writable, or both") + } + + return &fd, nil +} + +// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is +// expected that filesystesm will use this in a struct implementing +// vfs.FileDescriptionImpl. +type VFSPipeFD struct { + pipe *Pipe + flags uint32 + readable bool + writable bool + vfsfd *vfs.FileDescription +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *VFSPipeFD) Release() { + var event waiter.EventMask + if fd.readable { + fd.pipe.rClose() + event |= waiter.EventIn + } + if fd.writable { + fd.pipe.wClose() + event |= waiter.EventOut + } + if event == 0 { + panic("invalid pipe flags: must be readable, writable, or both") + } + + if fd.writable { + fd.vfsfd.VirtualDentry().Mount().EndWrite() + } + + fd.pipe.Notify(event) +} + +// OnClose implements vfs.FileDescriptionImpl.OnClose. +func (fd *VFSPipeFD) OnClose(_ context.Context) error { + return nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { + if !fd.readable { + return 0, syserror.EINVAL + } + + return fd.pipe.Read(ctx, dst) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { + if !fd.writable { + return 0, syserror.EINVAL + } + + return fd.pipe.Write(ctx, src) +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.pipe.Ioctl(ctx, uio, args) +} diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go index c5d095af7..2e861a5a8 100644 --- a/pkg/sentry/kernel/posixtimer.go +++ b/pkg/sentry/kernel/posixtimer.go @@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() { } // Notify implements ktime.TimerListener.Notify. -func (it *IntervalTimer) Notify(exp uint64) { +func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { if it.target == nil { - return + return ktime.Setting{}, false } it.target.tg.pidns.owner.mu.RLock() @@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) { if it.sigpending { it.overrunCur += exp - return + return ktime.Setting{}, false } // sigpending must be set before sendSignalTimerLocked() so that it can be @@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) { if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil { it.signalRejectedLocked() } + + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. Users of Timer should call diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go index 0acdf769d..61e412911 100644 --- a/pkg/sentry/kernel/ptrace_arm64.go +++ b/pkg/sentry/kernel/ptrace_arm64.go @@ -17,7 +17,6 @@ package kernel import ( - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 1725b8562..98ea7a0d8 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 36edf10f3..f4c00cd86 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "waiter_list", diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 93fe68a3e..de9617e9d 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred return syserror.ERANGE } - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = val sem.pid = pid s.changeTime = ktime.NowFromContext(ctx) @@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti for i, val := range vals { sem := &s.sems[i] - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = int16(val) sem.pid = pid sem.wakeWaiters() @@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch } // All operations succeeded, apply them. - // TODO(b/29354920): handle undo operations. + // TODO(gvisor.dev/issue/137): handle undo operations. for i, v := range tmpVals { s.sems[i].value = v s.sems[i].wakeWaiters() diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 81fcd8258..047b5214d 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -47,6 +47,11 @@ type Session struct { // The id is immutable. id SessionID + // foreground is the foreground process group. + // + // This is protected by TaskSet.mu. + foreground *ProcessGroup + // ProcessGroups is a list of process groups in this Session. This is // protected by TaskSet.mu. processGroups processGroupList @@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error { func (tg *ThreadGroup) CreateSession() error { tg.pidns.owner.mu.Lock() defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() return tg.createSession() } // createSession creates a new session for a threadgroup. // -// Precondition: callers must hold TaskSet.mu for writing. +// Precondition: callers must hold TaskSet.mu and the signal mutex for writing. func (tg *ThreadGroup) createSession() error { // Get the ID for this thread in the current namespace. id := tg.pidns.tgids[tg] @@ -321,8 +328,14 @@ func (tg *ThreadGroup) createSession() error { childTG.processGroup.incRefWithParent(pg) childTG.processGroup.decRefWithParent(oldParentPG) }) - tg.processGroup.decRefWithParent(oldParentPG) + // If tg.processGroup is an orphan, decRefWithParent will lock + // the signal mutex of each thread group in tg.processGroup. + // However, tg's signal mutex may already be locked at this + // point. We change tg's process group before calling + // decRefWithParent to avoid locking tg's signal mutex twice. + oldPG := tg.processGroup tg.processGroup = pg + oldPG.decRefWithParent(oldParentPG) } else { // The current process group may be nil only in the case of an // unparented thread group (i.e. the init process). This would @@ -346,6 +359,9 @@ func (tg *ThreadGroup) createSession() error { ns.processGroups[ProcessGroupID(local)] = pg } + // Disconnect from the controlling terminal. + tg.tty = nil + return nil } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index aa7471eb6..cd48945e6 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "shm", srcs = [ diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD new file mode 100644 index 000000000..9f7e19b4d --- /dev/null +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -0,0 +1,22 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "signalfd", + srcs = ["signalfd.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/context", + "//pkg/sentry/fs", + "//pkg/sentry/fs/anon", + "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", + "//pkg/sentry/usermem", + "//pkg/syserror", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go new file mode 100644 index 000000000..4b08d7d72 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -0,0 +1,140 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package signalfd provides an implementation of signal file descriptors. +package signalfd + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/anon" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SignalOperations represent a file with signalfd semantics. +// +// +stateify savable +type SignalOperations struct { + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FilePipeSeek `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoFsync `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + // target is the original task target. + // + // The semantics here are a bit broken. Linux will always use current + // for all reads, regardless of where the signalfd originated. We can't + // do exactly that because we need to plumb the context through + // EventRegister in order to support proper blocking behavior. This + // will undoubtedly become very complicated quickly. + target *kernel.Task + + // mu protects below. + mu sync.Mutex `state:"nosave"` + + // mask is the signal mask. Protected by mu. + mask linux.SignalSet +} + +// New creates a new signalfd object with the supplied mask. +func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // No task context? Not valid. + return nil, syserror.EINVAL + } + // name matches fs/signalfd.c:signalfd4. + dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]") + return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{ + target: t, + mask: mask, + }), nil +} + +// Release implements fs.FileOperations.Release. +func (s *SignalOperations) Release() {} + +// Mask returns the signal mask. +func (s *SignalOperations) Mask() linux.SignalSet { + s.mu.Lock() + mask := s.mask + s.mu.Unlock() + return mask +} + +// SetMask sets the signal mask. +func (s *SignalOperations) SetMask(mask linux.SignalSet) { + s.mu.Lock() + s.mask = mask + s.mu.Unlock() +} + +// Read implements fs.FileOperations.Read. +func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { + // Attempt to dequeue relevant signals. + info, err := s.target.Sigtimedwait(s.Mask(), 0) + if err != nil { + // There must be no signal available. + return 0, syserror.ErrWouldBlock + } + + // Copy out the signal info using the specified format. + var buf [128]byte + binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + Signo: uint32(info.Signo), + Errno: info.Errno, + Code: info.Code, + PID: uint32(info.Pid()), + UID: uint32(info.Uid()), + Status: info.Status(), + Overrun: uint32(info.Overrun()), + Addr: info.Addr(), + }) + n, err := dst.CopyOut(ctx, buf[:]) + return int64(n), err +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 && s.target.PendingSignals()&s.Mask() != 0 { + return waiter.EventIn // Pending signals. + } + return 0 +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) { + // Register for the signal set; ignore the passed events. + s.target.SignalRegister(entry, waiter.EventMask(s.Mask())) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SignalOperations) EventUnregister(entry *waiter.Entry) { + // Unregister the original entry. + s.target.SignalUnregister(entry) +} diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 220fa73a2..2fdee0282 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -339,6 +339,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn { return nil } +// LookupName looks up a syscall name. +func (s *SyscallTable) LookupName(sysno uintptr) string { + if sc, ok := s.Table[sysno]; ok { + return sc.Name + } + return fmt.Sprintf("sys_%d", sysno) // Unlikely. +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index e91f82bb3..ab0c6c4aa 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -15,6 +15,8 @@ package kernel import ( + gocontext "context" + "runtime/trace" "sync" "sync/atomic" @@ -35,7 +37,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/waiter" ) // Task represents a thread of execution in the untrusted app. It @@ -82,7 +85,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. - goschedSeq gvsync.SeqCount `state:"nosave"` + goschedSeq syncutil.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called @@ -133,6 +136,13 @@ type Task struct { // signalStack is exclusive to the task goroutine. signalStack arch.SignalStack + // signalQueue is a set of registered waiters for signal-related events. + // + // signalQueue is protected by the signalMutex. Note that the task does + // not implement all queue methods, specifically the readiness checks. + // The task only broadcast a notification on signal delivery. + signalQueue waiter.Queue `state:"zerovalue"` + // If groupStopPending is true, the task should participate in a group // stop in the interrupt path. // @@ -382,7 +392,14 @@ type Task struct { // logPrefix is a string containing the task's thread ID in the root PID // namespace, and is prepended to log messages emitted by Task.Infof etc. - logPrefix atomic.Value `state:".(string)"` + logPrefix atomic.Value `state:"nosave"` + + // traceContext and traceTask are both used for tracing, and are + // updated along with the logPrefix in updateInfoLocked. + // + // These are exclusive to the task goroutine. + traceContext gocontext.Context `state:"nosave"` + traceTask *trace.Task `state:"nosave"` // creds is the task's credentials. // @@ -520,14 +537,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) { t.ptraceTracer.Store(tracer) } -func (t *Task) saveLogPrefix() string { - return t.logPrefix.Load().(string) -} - -func (t *Task) loadLogPrefix(prefix string) { - t.logPrefix.Store(prefix) -} - func (t *Task) saveSyscallFilters() []bpf.Program { if f := t.syscallFilters.Load(); f != nil { return f.([]bpf.Program) @@ -541,6 +550,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) { // afterLoad is invoked by stateify. func (t *Task) afterLoad() { + t.updateInfoLocked() t.interruptChan = make(chan struct{}, 1) t.gosched.State = TaskGoroutineNonexistent if t.stop != nil { @@ -701,9 +711,9 @@ func (t *Task) FDTable() *FDTable { return t.fdTable } -// GetFile is a convenience wrapper t.FDTable().GetFile. +// GetFile is a convenience wrapper for t.FDTable().Get. // -// Precondition: same as FDTable. +// Precondition: same as FDTable.Get. func (t *Task) GetFile(fd int32) *fs.File { f, _ := t.fdTable.Get(fd) return f diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index dd69939f9..4a4a69ee2 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -16,6 +16,7 @@ package kernel import ( "runtime" + "runtime/trace" "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { runtime.Gosched() } + region := trace.StartRegion(t.traceContext, blockRegion) select { case <-C: + region.End() t.SleepFinish(true) + // Woken by event. return nil case <-interrupt: + region.End() t.SleepFinish(false) // Return the indicated error on interrupt. return syserror.ErrInterrupted case <-timerChan: - // We've timed out. + region.End() t.SleepFinish(true) + // We've timed out. return syserror.ETIMEDOUT } } diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 0916fd658..3eadfedb4 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -299,6 +299,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { // nt that it must receive before its task goroutine starts running. tid := nt.k.tasks.Root.IDOfTask(nt) defer nt.Start(tid) + t.traceCloneEvent(tid) // "If fork/clone and execve are allowed by @prog, any child processes will // be constrained to the same filters and system call ABI as the parent." - diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 8639d379f..bb5560acf 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -18,10 +18,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -132,30 +130,21 @@ func (t *Task) Stack() *arch.Stack { return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} } -// LoadTaskImage loads filename into a new TaskContext. +// LoadTaskImage loads a specified file into a new TaskContext. // -// It takes several arguments: -// * mounts: MountNamespace to lookup filename in -// * root: Root to lookup filename under -// * wd: Working directory to lookup filename under -// * maxTraversals: maximum number of symlinks to follow -// * filename: path to binary to load -// * file: an open fs.File object of the binary to load. If set, -// file will be loaded and not filename. -// * argv: Binary argv -// * envv: Binary envv -// * fs: Binary FeatureSet -func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving filename. - if file != nil { - filename = file.MappedName(ctx) +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.MappedName(ctx) } // Prepare a new user address space to load into. m := mm.NewMemoryManager(k, k) defer m.DecUsers(ctx) + args.MemoryManager = m - os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso) + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) if err != nil { return nil, err } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 17a089b90..90a6190f1 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -129,6 +129,7 @@ type runSyscallAfterExecStop struct { } func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { + t.traceExecEvent(r.tc) t.tg.pidns.owner.mu.Lock() t.tg.execing = nil if t.killed() { @@ -253,7 +254,7 @@ func (t *Task) promoteLocked() { t.tg.leader = t t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t]) - t.updateLogPrefixLocked() + t.updateInfoLocked() // Reap the original leader. If it has a tracer, detach it instead of // waiting for it to acknowledge the original leader's death. oldLeader.exitParentNotified = true diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 535f03e50..435761e5a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState { type runExitMain struct{} func (*runExitMain) execute(t *Task) taskRunState { + t.traceExitEvent() lastExiter := t.exitThreadGroup() // If the task has a cleartid, and the thread group wasn't killed by a diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index 78ff14b20..ce3e6ef28 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) { // disables the features we don't support anyway, is always set. This // drastically simplifies this function. // -// - We don't implement AT_SECURE, because no_new_privs always being set means -// that the conditions that require AT_SECURE never arise. (Compare Linux's +// - We don't set AT_SECURE = 1, because no_new_privs always being set means +// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's // security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().) // // - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index a29e9b9eb..0fb3661de 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -16,6 +16,7 @@ package kernel import ( "fmt" + "runtime/trace" "sort" "gvisor.dev/gvisor/pkg/log" @@ -127,11 +128,88 @@ func (t *Task) debugDumpStack() { } } -// updateLogPrefix updates the task's cached log prefix to reflect its -// current thread ID. +// trace definitions. +// +// Note that all region names are prefixed by ':' in order to ensure that they +// are lexically ordered before all system calls, which use the naked system +// call name (e.g. "read") for maximum clarity. +const ( + traceCategory = "task" + runRegion = ":run" + blockRegion = ":block" + cpuidRegion = ":cpuid" + faultRegion = ":fault" +) + +// updateInfoLocked updates the task's cached log prefix and tracing +// information to reflect its current thread ID. // // Preconditions: The task's owning TaskSet.mu must be locked. -func (t *Task) updateLogPrefixLocked() { +func (t *Task) updateInfoLocked() { // Use the task's TID in the root PID namespace for logging. - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t])) + tid := t.tg.pidns.owner.Root.tids[t] + t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.rebuildTraceContext(tid) +} + +// rebuildTraceContext rebuilds the trace context. +// +// Precondition: the passed tid must be the tid in the root namespace. +func (t *Task) rebuildTraceContext(tid ThreadID) { + // Re-initialize the trace context. + if t.traceTask != nil { + t.traceTask.End() + } + + // Note that we define the "task type" to be the dynamic TID. This does + // not align perfectly with the documentation for "tasks" in the + // tracing package. Tasks may be assumed to be bounded by analysis + // tools. However, if we just use a generic "task" type here, then the + // "user-defined tasks" page on the tracing dashboard becomes nearly + // unusable, as it loads all traces from all tasks. + // + // We can assume that the number of tasks in the system is not + // arbitrarily large (in general it won't be, especially for cases + // where we're collecting a brief profile), so using the TID is a + // reasonable compromise in this case. + t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid)) +} + +// traceCloneEvent is called when a new task is spawned. +// +// ntid must be the new task's ThreadID in the root namespace. +func (t *Task) traceCloneEvent(ntid ThreadID) { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid) +} + +// traceExitEvent is called when a task exits. +func (t *Task) traceExitEvent() { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status()) +} + +// traceExecEvent is called when a task calls exec. +func (t *Task) traceExecEvent(tc *TaskContext) { + if !trace.IsEnabled() { + return + } + d := tc.MemoryManager.Executable() + if d == nil { + trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") + return + } + defer d.DecRef() + root := t.fsContext.RootDirectory() + if root == nil { + trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>") + return + } + defer root.DecRef() + n, _ := d.FullName(root) + trace.Logf(t.traceContext, traceCategory, "exec: %s", n) } diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index c92266c59..d97f8c189 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -17,6 +17,7 @@ package kernel import ( "bytes" "runtime" + "runtime/trace" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -205,9 +206,11 @@ func (*runApp) execute(t *Task) taskRunState { t.tg.pidns.owner.mu.RUnlock() } + region := trace.StartRegion(t.traceContext, runRegion) t.accountTaskGoroutineEnter(TaskGoroutineRunningApp) info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU) t.accountTaskGoroutineLeave(TaskGoroutineRunningApp) + region.End() if clearSinglestep { t.Arch().ClearSingleStep() @@ -225,6 +228,7 @@ func (*runApp) execute(t *Task) taskRunState { case platform.ErrContextSignalCPUID: // Is this a CPUID instruction? + region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) @@ -232,10 +236,12 @@ func (*runApp) execute(t *Task) taskRunState { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) + region.End() // Resume execution. return (*runApp)(nil) } + region.End() // Not an actual CPUID, but required copy-in. // The instruction at the given RIP was not a CPUID, and we // fallthrough to the default signal deliver behavior below. @@ -251,8 +257,10 @@ func (*runApp) execute(t *Task) taskRunState { // an application-generated signal and we should continue execution // normally. if at.Any() { + region := trace.StartRegion(t.traceContext, faultRegion) addr := usermem.Addr(info.Addr()) err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack())) + region.End() if err == nil { // The fault was handled appropriately. // We can resume running the application. @@ -260,6 +268,12 @@ func (*runApp) execute(t *Task) taskRunState { } // Is this a vsyscall that we need emulate? + // + // Note that we don't track vsyscalls as part of a + // specific trace region. This is because regions don't + // stack, and the actual system call will count as a + // region. We should be able to easily identify + // vsyscalls by having a <fault><syscall> pair. if at.Execute { if sysno, ok := t.tc.st.LookupEmulate(addr); ok { return t.doVsyscall(addr, sysno) diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index e76c069b0..8b148db35 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) { t.gosched.Timestamp = now t.gosched.State = state t.goschedSeq.EndWrite() + + if state != TaskGoroutineRunningApp { + // Task is blocking/stopping. + t.k.decRunningTasks() + } } // Preconditions: The caller must be running on the task goroutine, and leaving // a state indicated by a previous call to // t.accountTaskGoroutineEnter(state). func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { + if state != TaskGoroutineRunningApp { + // Task is unblocking/continuing. + t.k.incRunningTasks() + } + now := t.k.CPUClockNow() if t.gosched.State != state { panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys)) @@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker { } // Notify implements ktime.TimerListener.Notify. -func (ticker *kernelCPUClockTicker) Notify(exp uint64) { +func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { // Only increment cpuClock by 1 regardless of the number of expirations. // This approximately compensates for cases where thread throttling or bad // Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and @@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) { tgs[i] = nil } ticker.tgs = tgs[:0] + + // If nothing is running, we can disable the timer. + tasks := atomic.LoadInt64(&ticker.k.runningTasks) + if tasks == 0 { + ticker.k.runningTasksMu.Lock() + defer ticker.k.runningTasksMu.Unlock() + tasks := atomic.LoadInt64(&ticker.k.runningTasks) + if tasks != 0 { + // Raced with a 0 -> 1 transition. + return setting, false + } + + // Stop the timer. We must cache the current setting so the + // kernel can access it without violating the lock order. + ticker.k.cpuClockTickerSetting = setting + ticker.k.cpuClockTickerDisabled = true + setting.Enabled = false + return setting, true + } + + return setting, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 266959a07..39cd1340d 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -28,6 +28,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" ) // SignalAction is an internal signal action. @@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) { // // Preconditions: The signal mutex must be locked. func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool { + // Notify that the signal is queued. + t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig))) + // - Do not choose tasks that are blocking the signal. if linux.SignalSetOf(sig)&t.signalMask != 0 { return false @@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { t.tg.signalHandlers.mu.Unlock() return t.deliverSignal(info, act) } + +// SignalRegister registers a waiter for pending signals. +func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventRegister(e, mask) + t.tg.signalHandlers.mu.Unlock() +} + +// SignalUnregister unregisters a waiter for pending signals. +func (t *Task) SignalUnregister(e *waiter.Entry) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventUnregister(e) + t.tg.signalHandlers.mu.Unlock() +} diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index d60cd62c7..3522a4ae5 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -154,10 +154,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { // Below this point, newTask is expected not to fail (there is no rollback // of assignTIDsLocked or any of the following). - // Logging on t's behalf will panic if t.logPrefix hasn't been initialized. - // This is the earliest point at which we can do so (since t now has thread - // IDs). - t.updateLogPrefixLocked() + // Logging on t's behalf will panic if t.logPrefix hasn't been + // initialized. This is the earliest point at which we can do so + // (since t now has thread IDs). + t.updateInfoLocked() if cfg.InheritParent != nil { t.parent = cfg.InheritParent.parent @@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { if parentPG := tg.parentPG(); parentPG == nil { tg.createSession() } else { - // Inherit the process group. + // Inherit the process group and terminal. parentPG.incRefWithParent(parentPG) tg.processGroup = parentPG + tg.tty = t.parent.tg.tty } } tg.tasks.PushBack(t) diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index b543d536a..3180f5560 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -17,6 +17,7 @@ package kernel import ( "fmt" "os" + "runtime/trace" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u ctrl = ctrlStopAndReinvokeSyscall } else { fn := s.Lookup(sysno) + var region *trace.Region // Only non-nil if tracing == true. + if trace.IsEnabled() { + region = trace.StartRegion(t.traceContext, s.LookupName(sysno)) + } if fn != nil { // Call our syscall implementation. rval, ctrl, err = fn(t, args) @@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u // Use the missing function if not found. rval, err = t.SyscallTable().Missing(t, sysno, args) } + if region != nil { + region.End() + } } if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) { diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 2a97e3e8e..72568d296 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -19,10 +19,13 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/syserror" ) // A ThreadGroup is a logical grouping of tasks that has widespread @@ -245,6 +248,12 @@ type ThreadGroup struct { // // mounts is immutable. mounts *fs.MountNamespace + + // tty is the thread group's controlling terminal. If nil, there is no + // controlling terminal. + // + // tty is protected by the signal mutex. + tty *TTY } // newThreadGroup returns a new, empty thread group in PID namespace ns. The @@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) { } } +// SetControllingTTY sets tty as the controlling terminal of tg. +func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "The calling process must be a session leader and not have a + // controlling terminal already." - tty_ioctl(4) + if tg.processGroup.session.leader != tg || tg.tty != nil { + return syserror.EINVAL + } + + // "If this terminal is already the controlling terminal of a different + // session group, then the ioctl fails with EPERM, unless the caller + // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the + // terminal is stolen, and all processes that had it as controlling + // terminal lose it." - tty_ioctl(4) + if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session { + if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 { + return syserror.EPERM + } + // Steal the TTY away. Unlike TIOCNOTTY, don't send signals. + for othertg := range tg.pidns.owner.Root.tgids { + // This won't deadlock by locking tg.signalHandlers + // because at this point: + // - We only lock signalHandlers if it's in the same + // session as the tty's controlling thread group. + // - We know that the calling thread group is not in + // the same session as the tty's controlling thread + // group. + if othertg.processGroup.session == tty.tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + othertg.signalHandlers.mu.Unlock() + } + } + } + + // Set the controlling terminal and foreground process group. + tg.tty = tty + tg.processGroup.session.foreground = tg.processGroup + // Set this as the controlling process of the terminal. + tty.tg = tg + + return nil +} + +// ReleaseControllingTTY gives up tty as the controlling tty of tg. +func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.RLock() + defer tg.pidns.owner.mu.RUnlock() + + // Just below, we may re-lock signalHandlers in order to send signals. + // Thus we can't defer Unlock here. + tg.signalHandlers.mu.Lock() + + if tg.tty == nil || tg.tty != tty { + tg.signalHandlers.mu.Unlock() + return syserror.ENOTTY + } + + // "If the process was session leader, then send SIGHUP and SIGCONT to + // the foreground process group and all processes in the current + // session lose their controlling terminal." - tty_ioctl(4) + // Remove tty as the controlling tty for each process in the session, + // then send them SIGHUP and SIGCONT. + + // If we're not the session leader, we don't have to do much. + if tty.tg != tg { + tg.tty = nil + tg.signalHandlers.mu.Unlock() + return nil + } + + tg.signalHandlers.mu.Unlock() + + // We're the session leader. SIGHUP and SIGCONT the foreground process + // group and remove all controlling terminals in the session. + var lastErr error + for othertg := range tg.pidns.owner.Root.tgids { + if othertg.processGroup.session == tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + if othertg.processGroup == tg.processGroup.session.foreground { + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { + lastErr = err + } + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { + lastErr = err + } + } + othertg.signalHandlers.mu.Unlock() + } + } + + return lastErr +} + +// ForegroundProcessGroup returns the process group ID of the foreground +// process group. +func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "When fd does not refer to the controlling terminal of the calling + // process, -1 is returned" - tcgetpgrp(3) + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + return int32(tg.processGroup.session.foreground.id), nil +} + +// SetForegroundProcessGroup sets the foreground process group of tty to pgid. +func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // TODO(b/129283598): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is + // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all + // members of this background process group." + + // tty must be the controlling terminal. + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + // pgid must be positive. + if pgid < 0 { + return -1, syserror.EINVAL + } + + // pg must not be empty. Empty process groups are removed from their + // pid namespaces. + pg, ok := tg.pidns.processGroups[pgid] + if !ok { + return -1, syserror.ESRCH + } + + // pg must be part of this process's session. + if tg.processGroup.session != pg.session { + return -1, syserror.EPERM + } + + tg.processGroup.session.foreground.id = pgid + return 0, nil +} + // itimerRealListener implements ktime.Listener for ITIMER_REAL expirations. // // +stateify savable @@ -332,8 +511,9 @@ type itimerRealListener struct { } // Notify implements ktime.TimerListener.Notify. -func (l *itimerRealListener) Notify(exp uint64) { +func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM)) + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 9beae4b31..31847e1df 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "time", srcs = [ diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index aa6c75d25..107394183 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { // A TimerListener receives expirations from a Timer. type TimerListener interface { // Notify is called when its associated Timer expires. exp is the number of - // expirations. + // expirations. setting is the next timer Setting. // // Notify is called with the associated Timer's mutex locked, so Notify // must not take any locks that precede Timer.mu in lock order. // + // If Notify returns true, the timer will use the returned setting + // rather than the passed one. + // // Preconditions: exp > 0. - Notify(exp uint64) + Notify(exp uint64, setting Setting) (newSetting Setting, update bool) // Destroy is called when the timer is destroyed. Destroy() @@ -533,7 +536,9 @@ func (t *Timer) Tick() { s, exp := t.setting.At(now) t.setting = s if exp > 0 { - t.listener.Notify(exp) + if newS, ok := t.listener.Notify(exp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) } @@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) { s, exp := t.setting.At(now) t.setting = s if exp > 0 { - t.listener.Notify(exp) + if newS, ok := t.listener.Notify(exp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) return now, s @@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { } oldS, oldExp := t.setting.At(now) if oldExp > 0 { - t.listener.Notify(oldExp) + t.listener.Notify(oldExp, oldS) + // N.B. The returned Setting doesn't matter because we're about + // to overwrite. } if f != nil { f() @@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { newS, newExp := s.At(now) t.setting = newS if newExp > 0 { - t.listener.Notify(newExp) + if newS, ok := t.listener.Notify(newExp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) return now, oldS @@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) { } // Notify implements ktime.TimerListener.Notify. -func (c *ChannelNotifier) Notify(uint64) { +func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) { select { case c.tchan <- struct{}{}: default: } + + return Setting{}, false } // Destroy implements ktime.TimerListener.Destroy and will close the channel. diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go new file mode 100644 index 000000000..048de26dc --- /dev/null +++ b/pkg/sentry/kernel/tty.go @@ -0,0 +1,39 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync" + +// TTY defines the relationship between a thread group and its controlling +// terminal. +// +// +stateify savable +type TTY struct { + // Index is the terminal index. It is immutable. + Index uint32 + + mu sync.Mutex `state:"nosave"` + + // tg is protected by mu. + tg *ThreadGroup +} + +// TTY returns the thread group's controlling terminal. If nil, there is no +// controlling terminal. +func (tg *ThreadGroup) TTY() *TTY { + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + return tg.tty +} diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 40025d62d..156e67bf8 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "limits", diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 3b322f5f3..2890393bd 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_embed_data( name = "vdso_bin", src = "//vdso:vdso.so", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index ba9c9ce12..6299a3e2f 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf. return syserror.ENOEXEC } + // N.B. Linux uses vm_brk_flags to map these pages, which only + // honors the X bit, always mapping at least RW. ignoring These + // pages are not included in the final brk region. + prot := usermem.ReadWrite + if phdr.Flags&elf.PF_X == elf.PF_X { + prot.Execute = true + } + if _, err := m.MMap(ctx, memmap.MMapOpts{ Length: uint64(anonSize), Addr: anonAddr, // Fixed without Unmap will fail the mmap if something is // already at addr. - Fixed: true, - Private: true, - // N.B. Linux uses vm_brk to map these pages, ignoring - // the segment protections, instead always mapping RW. - // These pages are not included in the final brk - // region. - Perms: usermem.ReadWrite, + Fixed: true, + Private: true, + Perms: prot, MaxPerms: usermem.AnyAccess, }); err != nil { ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err) @@ -404,6 +408,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el start = vaddr } if vaddr < end { + // NOTE(b/37474556): Linux allows out-of-order + // segments, in violation of the spec. ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end) return loadedELF{}, syserror.ENOEXEC } @@ -620,15 +626,15 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, in return loadParsedELF(ctx, m, f, info, 0) } -// loadELF loads f into the Task address space. +// loadELF loads args.File into the Task address space. // // If loadELF returns ErrSwitchFile it should be called again with the returned // path and argv. // // Preconditions: -// * f is an ELF file -func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) { - bin, ac, err := loadInitialELF(ctx, m, fs, f) +// * args.File is an ELF file +func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { + bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { ctx.Infof("Error loading binary: %v", err) return loadedELF{}, nil, err @@ -636,7 +642,14 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace var interp loadedELF if bin.interpreter != "" { - d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter) + // Even if we do not allow the final link of the script to be + // resolved, the interpreter should still be resolved if it is + // a symlink. + args.ResolveFinal = true + // Refresh the traversal limit. + *args.RemainingTraversals = linux.MaxSymlinkTraversals + args.Filename = bin.interpreter + d, i, err := openPath(ctx, args) if err != nil { ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err @@ -645,7 +658,7 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace // We don't need the Dirent. d.DecRef() - interp, err = loadInterpreterELF(ctx, m, i, bin) + interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin) if err != nil { ctx.Infof("Error loading interpreter: %v", err) return loadedELF{}, nil, err diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index f6f1ae762..b03eeb005 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package loader loads a binary into a MemoryManager. +// Package loader loads an executable file into a MemoryManager. package loader import ( @@ -20,6 +20,7 @@ import ( "fmt" "io" "path" + "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" @@ -35,6 +36,54 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LoadArgs holds specifications for an executable file to be loaded. +type LoadArgs struct { + // MemoryManager is the memory manager to load the executable into. + MemoryManager *mm.MemoryManager + + // Mounts is the mount namespace in which to look up Filename. + Mounts *fs.MountNamespace + + // Root is the root directory under which to look up Filename. + Root *fs.Dirent + + // WorkingDirectory is the working directory under which to look up + // Filename. + WorkingDirectory *fs.Dirent + + // RemainingTraversals is the maximum number of symlinks to follow to + // resolve Filename. This counter is passed by reference to keep it + // updated throughout the call stack. + RemainingTraversals *uint + + // ResolveFinal indicates whether the final link of Filename should be + // resolved, if it is a symlink. + ResolveFinal bool + + // Filename is the path for the executable. + Filename string + + // File is an open fs.File object of the executable. If File is not + // nil, then File will be loaded and Filename will be ignored. + File *fs.File + + // CloseOnExec indicates that the executable (or one of its parent + // directories) was opened with O_CLOEXEC. If the executable is an + // interpreter script, then cause an ENOENT error to occur, since the + // script would otherwise be inaccessible to the interpreter. + CloseOnExec bool + + // Argv is the vector of arguments to pass to the executable. + Argv []string + + // Envv is the vector of environment variables to pass to the + // executable. + Envv []string + + // Features specifies the CPU feature set for the executable. + Features *cpuid.FeatureSet +} + // readFull behaves like io.ReadFull for an *fs.File. func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { var total int64 @@ -51,80 +100,82 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in return total, nil } -// openPath opens name for loading. +// openPath opens args.Filename and checks that it is valid for loading. // -// openPath returns the fs.Dirent and an *fs.File for name, which is not +// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not // installed in the Task FDTable. The caller takes ownership of both. // -// name must be a readable, executable, regular file. -func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) { - if name == "" { +// args.Filename must be a readable, executable, regular file. +func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) { + if args.Filename == "" { ctx.Infof("cannot open empty name") return nil, nil, syserror.ENOENT } - d, err := mm.FindInode(ctx, root, wd, name, maxTraversals) + var d *fs.Dirent + var err error + if args.ResolveFinal { + d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } else { + d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } if err != nil { return nil, nil, err } - - // Open file will take a reference to Dirent, so destroy this one. + // Defer a DecRef for the sake of failure cases. defer d.DecRef() - return openFile(ctx, nil, d, name) -} + if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) { + return nil, nil, syserror.ELOOP + } -// openFile performs checks on a file to be executed. If provided a *fs.File, -// openFile takes that file's Dirent and performs checks on it. If provided a -// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's -// Inode and performs checks on that. -// -// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership -// of both. -// -// "dirent" and "file" must not both be nil and point to a readable, executable, regular file. -func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) { - // file and dirent must not be nil. - if dirent == nil && file == nil { - ctx.Infof("dirent and file cannot both be nil.") - return nil, nil, syserror.ENOENT + if err := checkPermission(ctx, d); err != nil { + return nil, nil, err } - if file != nil { - dirent = file.Dirent + // If they claim it's a directory, then make sure. + // + // N.B. we reject directories below, but we must first reject + // non-directories passed as directories. + if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) { + return nil, nil, syserror.ENOTDIR } - // Perform permissions checks on the file. - if err := checkFile(ctx, dirent, name); err != nil { + if err := checkIsRegularFile(ctx, d, args.Filename); err != nil { return nil, nil, err } - if file == nil { - var ferr error - if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil { - return nil, nil, ferr - } - } else { - // GetFile takes a reference to the created file, so make one in the case - // that the file reference already existed. - file.IncRef() + f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) + if err != nil { + return nil, nil, err + } + // Defer a DecRef for the sake of failure cases. + defer f.DecRef() + + if err := checkPread(ctx, f, args.Filename); err != nil { + return nil, nil, err + } + + d.IncRef() + f.IncRef() + return d, f, err +} + +// checkFile performs checks on a file to be executed. +func checkFile(ctx context.Context, f *fs.File, filename string) error { + if err := checkPermission(ctx, f.Dirent); err != nil { + return err } - // We must be able to read at arbitrary offsets. - if !file.Flags().Pread { - file.DecRef() - ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags()) - return nil, nil, syserror.EACCES + if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil { + return err } - // Grab reference for caller. - dirent.IncRef() - return dirent, file, nil + return checkPread(ctx, f, filename) } -// checkFile performs file permissions checks for binaries called in openPath -// and openFile -func checkFile(ctx context.Context, d *fs.Dirent, name string) error { +// checkPermission checks whether the file is readable and executable. +func checkPermission(ctx context.Context, d *fs.Dirent) error { perms := fs.PermMask{ // TODO(gvisor.dev/issue/160): Linux requires only execute // permission, not read. However, our backing filesystems may @@ -135,26 +186,26 @@ func checkFile(ctx context.Context, d *fs.Dirent, name string) error { Read: true, Execute: true, } - if err := d.Inode.CheckPermission(ctx, perms); err != nil { - return err - } + return d.Inode.CheckPermission(ctx, perms) +} - // If they claim it's a directory, then make sure. - // - // N.B. we reject directories below, but we must first reject - // non-directories passed as directories. - if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR +// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc. +func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error { + attr := d.Inode.StableAttr + if !fs.IsRegular(attr) { + ctx.Infof("%s is not regular: %v", filename, attr) + return syserror.EACCES } + return nil +} - // No exec-ing directories, pipes, etc! - if !fs.IsRegular(d.Inode.StableAttr) { - ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr) +// checkPread checks whether we can read the file at arbitrary offsets. +func checkPread(ctx context.Context, f *fs.File, filename string) error { + if !f.Flags().Pread { + ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags()) return syserror.EACCES } - return nil - } // allocStack allocates and maps a stack in to any available part of the address space. @@ -173,45 +224,49 @@ const ( maxLoaderAttempts = 6 ) -// loadBinary loads a binary that is pointed to by "file". If nil, the path -// "filename" is resolved and loaded. +// loadExecutable loads an executable that is pointed to by args.File. If nil, +// the path args.Filename is resolved and loaded. If the executable is an +// interpreter script rather than an ELF, the binary of the corresponding +// interpreter will be loaded. // // It returns: // * loadedELF, description of the loaded binary // * arch.Context matching the binary arch // * fs.Dirent of the binary file -// * Possibly updated argv -func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) { +// * Possibly updated args.Argv +func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) { for i := 0; i < maxLoaderAttempts; i++ { var ( d *fs.Dirent - f *fs.File err error ) - if passedFile == nil { - d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename) - + if args.File == nil { + d, args.File, err = openPath(ctx, args) + // We will return d in the successful case, but defer a DecRef for the + // sake of intermediate loops and failure cases. + if d != nil { + defer d.DecRef() + } + if args.File != nil { + defer args.File.DecRef() + } } else { - d, f, err = openFile(ctx, passedFile, nil, "") - // Set to nil in case we loop on a Interpreter Script. - passedFile = nil + d = args.File.Dirent + d.IncRef() + defer d.DecRef() + err = checkFile(ctx, args.File, args.Filename) } - if err != nil { - ctx.Infof("Error opening %s: %v", filename, err) + ctx.Infof("Error opening %s: %v", args.Filename, err) return loadedELF{}, nil, nil, nil, err } - defer f.DecRef() - // We will return d in the successful case, but defer a DecRef - // for intermediate loops and failure cases. - defer d.DecRef() // Check the header. Is this an ELF or interpreter script? var hdr [4]uint8 // N.B. We assume that reading from a regular file cannot block. - _, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0) - // Allow unexpected EOF, as a valid executable could be only three - // bytes (e.g., #!a). + _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0) + // Allow unexpected EOF, as a valid executable could be only three bytes + // (e.g., #!a). if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { err = syserror.ENOEXEC @@ -221,33 +276,38 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp switch { case bytes.Equal(hdr[:], []byte(elfMagic)): - loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f) + loaded, ac, err := loadELF(ctx, args) if err != nil { ctx.Infof("Error loading ELF: %v", err) return loadedELF{}, nil, nil, nil, err } // An ELF is always terminal. Hold on to d. d.IncRef() - return loaded, ac, d, argv, err + return loaded, ac, d, args.Argv, err case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)): - newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv) + if args.CloseOnExec { + return loadedELF{}, nil, nil, nil, syserror.ENOENT + } + args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv) if err != nil { ctx.Infof("Error loading interpreter script: %v", err) return loadedELF{}, nil, nil, nil, err } - filename = newpath - argv = newargv + // Refresh the traversal limit for the interpreter. + *args.RemainingTraversals = linux.MaxSymlinkTraversals default: ctx.Infof("Unknown magic: %v", hdr) return loadedELF{}, nil, nil, nil, syserror.ENOEXEC } + // Set to nil in case we loop on a Interpreter Script. + args.File = nil } return loadedELF{}, nil, nil, nil, syserror.ELOOP } -// Load loads "file" into a MemoryManager. If file is nil, the path "filename" -// is resolved and loaded instead. +// Load loads args.File into a MemoryManager. If args.File is nil, the path +// args.Filename is resolved and loaded instead. // // If Load returns ErrSwitchFile it should be called again with the returned // path and argv. @@ -255,37 +315,37 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp // Preconditions: // * The Task MemoryManager is empty. // * Load is called on the Task goroutine. -func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { - // Load the binary itself. - loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv) +func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { + // Load the executable itself. + loaded, ac, d, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer d.DecRef() // Load the VDSO. - vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded) + vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the - // binary. Userspace can assume that the remainer of the page after + // executable. Userspace can assume that the remainer of the page after // loaded.end is available for its use. e, ok := loaded.end.RoundUp() if !ok { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC) } - m.BrkSetup(ctx, e) + args.MemoryManager.BrkSetup(ctx, e) // Allocate our stack. - stack, err := allocStack(ctx, m, ac) + stack, err := allocStack(ctx, args.MemoryManager, ac) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux()) } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(filename) + execfn, err := stack.Push(args.Filename) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } @@ -308,6 +368,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())}, arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())}, arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())}, + // The conditions that require AT_SECURE = 1 never arise. See + // kernel.Task.updateCredsForExecLocked. + arch.AuxEntry{linux.AT_SECURE, 0}, arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC}, arch.AuxEntry{linux.AT_EXECFN, execfn}, arch.AuxEntry{linux.AT_RANDOM, random}, @@ -316,11 +379,12 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r }...) auxv = append(auxv, extraAuxv...) - sl, err := stack.Load(argv, envv, auxv) + sl, err := stack.Load(newArgv, args.Envv, auxv) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux()) } + m := args.MemoryManager m.SetArgvStart(sl.ArgvStart) m.SetArgvEnd(sl.ArgvEnd) m.SetEnvvStart(sl.EnvvStart) @@ -331,7 +395,7 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) - name := path.Base(filename) + name := path.Base(args.Filename) if len(name) > linux.TASK_COMM_LEN-1 { name = name[:linux.TASK_COMM_LEN-1] } diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index ada28aea3..df8a81907 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -268,6 +268,8 @@ func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, er // some applications may not be able to handle multiple [vdso] // hints. vdso: mm.NewSpecialMappable("", mfp, vdso), + os: info.os, + arch: info.arch, phdrs: info.phdrs, }, nil } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 29c14ec56..112794e9c 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "mappable_range", @@ -40,7 +41,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", - "//pkg/refs", "//pkg/sentry/context", "//pkg/sentry/platform", "//pkg/sentry/usermem", diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 03b99aaea..16a722a13 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -18,7 +18,6 @@ package memmap import ( "fmt" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" @@ -235,8 +234,11 @@ type InvalidateOpts struct { // coincidental; fs.File implements MappingIdentity, and some // fs.InodeOperations implement Mappable.) type MappingIdentity interface { - // MappingIdentity is reference-counted. - refs.RefCounter + // IncRef increments the MappingIdentity's reference count. + IncRef() + + // DecRef decrements the MappingIdentity's reference count. + DecRef() // MappedName returns the application-visible name shown in // /proc/[pid]/maps. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 072745a08..839931f67 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "file_refcount_set", @@ -117,9 +118,9 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/syncutil", "//pkg/syserror", "//pkg/tcpip/buffer", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index f350e0109..58a5c186d 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -44,7 +44,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // MemoryManager implements a virtual address space. @@ -82,7 +82,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu gvsync.DowngradableRWMutex `state:"nosave"` + mappingMu syncutil.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +125,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. - activeMu gvsync.DowngradableRWMutex `state:"nosave"` + activeMu syncutil.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 858f895f2..f404107af 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "evictable_range", diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index 1effc7735..aafce1d00 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -16,6 +16,7 @@ package pgalloc import ( "bytes" + "context" "fmt" "io" "runtime" @@ -29,7 +30,7 @@ import ( ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // Save metadata. - if err := state.Save(w, &f.fileSize, nil); err != nil { + if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { return err } - if err := state.Save(w, &f.usage, nil); err != nil { + if err := state.Save(ctx, w, &f.usage, nil); err != nil { return err } @@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { // Load metadata. - if err := state.Load(r, &f.fileSize, nil); err != nil { + if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(r, &f.usage, nil); err != nil { + if err := state.Load(ctx, r, &f.usage, nil); err != nil { return err } diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 9aa6ec507..157bffa81 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "file_range", out = "file_range.go", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index eeb634644..b6d008dbe 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 2046df525..f3afd98da 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -11,17 +12,26 @@ go_library( "bluepill_amd64.go", "bluepill_amd64.s", "bluepill_amd64_unsafe.go", + "bluepill_arm64.go", + "bluepill_arm64.s", + "bluepill_arm64_unsafe.go", "bluepill_fault.go", "bluepill_unsafe.go", "context.go", - "filters.go", + "filters_amd64.go", + "filters_arm64.go", "kvm.go", "kvm_amd64.go", "kvm_amd64_unsafe.go", + "kvm_arm64.go", + "kvm_arm64_unsafe.go", "kvm_const.go", + "kvm_const_arm64.go", "machine.go", "machine_amd64.go", "machine_amd64_unsafe.go", + "machine_arm64.go", + "machine_arm64_unsafe.go", "machine_unsafe.go", "physical_map.go", "physical_map_amd64.go", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index acd41f73d..ea8b9632e 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // not have physical mappings, the KVM module may inject // spurious exceptions when emulation fails (i.e. it tries to // emulate because the RIP is pointed at those pages). - as.machine.mapPhysical(physical, length) + as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) // Install the page table mappings. Note that the ordering is // important; if the pagetable mappings were installed before diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go index 80942e9c9..3f35414bb 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/allocator.go @@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { // //go:nosplit func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { - virtualStart, physicalStart, _, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) } diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 043de51b3..30dbb74d6 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -20,6 +20,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" ) @@ -36,6 +37,18 @@ func sighandler() func dieTrampoline() var ( + // bounceSignal is the signal used for bouncing KVM. + // + // We use SIGCHLD because it is not masked by the runtime, and + // it will be ignored properly by other parts of the kernel. + bounceSignal = syscall.SIGCHLD + + // bounceSignalMask has only bounceSignal set. + bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) + + // bounce is the interrupt vector used to return to the kernel. + bounce = uint32(ring0.VirtualizationException) + // savedHandler is a pointer to the previous handler. // // This is called by bluepillHandler. @@ -45,6 +58,13 @@ var ( dieTrampolineAddr uintptr ) +// redpill invokes a syscall with -1. +// +//go:nosplit +func redpill() { + syscall.RawSyscall(^uintptr(0), 0, 0, 0) +} + // dieHandler is called by dieTrampoline. // //go:nosplit @@ -73,8 +93,8 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { - panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err)) + if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } // Extract the address for the trampoline. diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 421c88220..133c2203d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -24,26 +24,10 @@ import ( ) var ( - // bounceSignal is the signal used for bouncing KVM. - // - // We use SIGCHLD because it is not masked by the runtime, and - // it will be ignored properly by other parts of the kernel. - bounceSignal = syscall.SIGCHLD - - // bounceSignalMask has only bounceSignal set. - bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) - - // bounce is the interrupt vector used to return to the kernel. - bounce = uint32(ring0.VirtualizationException) + // The action for bluepillSignal is changed by sigaction(). + bluepillSignal = syscall.SIGSEGV ) -// redpill on amd64 invokes a syscall with -1. -// -//go:nosplit -func redpill() { - syscall.RawSyscall(^uintptr(0), 0, 0, 0) -} - // bluepillArchEnter is called during bluepillEnter. // //go:nosplit diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 9d8af143e..a63a6a071 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -23,13 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) -// bluepillArchContext returns the arch-specific context. -// -//go:nosplit -func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { - return &((*arch.UContext64)(context).MContext) -} - // dieArchSetup initializes the state for dieTrampoline. // // The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go new file mode 100644 index 000000000..552341721 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -0,0 +1,79 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" +) + +var ( + // The action for bluepillSignal is changed by sigaction(). + bluepillSignal = syscall.SIGILL +) + +// bluepillArchEnter is called during bluepillEnter. +// +//go:nosplit +func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { + c = vCPUPtr(uintptr(context.Regs[8])) + regs := c.CPU.Registers() + regs.Regs = context.Regs + regs.Sp = context.Sp + regs.Pc = context.Pc + regs.Pstate = context.Pstate + regs.Pstate &^= uint64(ring0.KernelFlagsClear) + regs.Pstate |= ring0.KernelFlagsSet + return +} + +// bluepillArchExit is called during bluepillEnter. +// +//go:nosplit +func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { + regs := c.CPU.Registers() + context.Regs = regs.Regs + context.Sp = regs.Sp + context.Pc = regs.Pc + context.Pstate = regs.Pstate + context.Pstate &^= uint64(ring0.UserFlagsClear) + context.Pstate |= ring0.UserFlagsSet +} + +// KernelSyscall handles kernel syscalls. +// +//go:nosplit +func (c *vCPU) KernelSyscall() { + regs := c.Registers() + if regs.Regs[8] != ^uint64(0) { + regs.Pc -= 4 // Rewind. + } + ring0.Halt() +} + +// KernelException handles kernel exceptions. +// +//go:nosplit +func (c *vCPU) KernelException(vector ring0.Vector) { + regs := c.Registers() + if vector == ring0.Vector(bounce) { + regs.Pc = 0 + } + ring0.Halt() +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s new file mode 100644 index 000000000..c61700892 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -0,0 +1,87 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// VCPU_CPU is the location of the CPU in the vCPU struct. +// +// This is guaranteed to be zero. +#define VCPU_CPU 0x0 + +// CPU_SELF is the self reference in ring0's percpu. +// +// This is guaranteed to be zero. +#define CPU_SELF 0x0 + +// Context offsets. +// +// Only limited use of the context is done in the assembly stub below, most is +// done in the Go handlers. +#define SIGINFO_SIGNO 0x0 +#define CONTEXT_PC 0x1B8 +#define CONTEXT_R0 0xB8 + +// See bluepill.go. +TEXT ·bluepill(SB),NOSPLIT,$0 +begin: + MOVD vcpu+0(FP), R8 + MOVD $VCPU_CPU(R8), R9 + ORR $0xffff000000000000, R9, R9 + // Trigger sigill. + // In ring0.Start(), the value of R8 will be stored into tpidr_el1. + // When the context was loaded into vcpu successfully, + // we will check if the value of R10 and R9 are the same. + WORD $0xd538d08a // MRS TPIDR_EL1, R10 +check_vcpu: + CMP R10, R9 + BEQ right_vCPU +wrong_vcpu: + CALL ·redpill(SB) + B begin +right_vCPU: + RET + +// sighandler: see bluepill.go for documentation. +// +// The arguments are the following: +// +// R0 - The signal number. +// R1 - Pointer to siginfo_t structure. +// R2 - Pointer to ucontext structure. +// +TEXT ·sighandler(SB),NOSPLIT,$0 + // si_signo should be sigill. + MOVD SIGINFO_SIGNO(R1), R7 + CMPW $4, R7 + BNE fallback + + MOVD CONTEXT_PC(R2), R7 + CMPW $0, R7 + BEQ fallback + + MOVD R2, 8(RSP) + BL ·bluepillHandler(SB) // Call the handler. + + RET + +fallback: + // Jump to the previous signal handler. + MOVD ·savedHandler(SB), R7 + B (R7) + +// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. +TEXT ·dieTrampoline(SB),NOSPLIT,$0 + // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. + MOVD R9, 8(RSP) + BL ·dieHandler(SB) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go new file mode 100644 index 000000000..e5fac0d6a --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" +) + +//go:nosplit +func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { + // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. +} diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index b97476053..f6459cda9 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -46,9 +46,9 @@ func yield() { // calculateBluepillFault calculates the fault address range. // //go:nosplit -func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { +func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { alignedPhysical := physical &^ uintptr(usermem.PageSize-1) - for _, pr := range physicalRegions { + for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { continue @@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng // The corresponding virtual address is returned. This may throw on error. // //go:nosplit -func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { +func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) { // Paging fault: we need to map the underlying physical pages for this // fault. This all has to be done in this function because we're in a // signal handler context. (We can't call any functions that might // split the stack.) - virtualStart, physicalStart, length, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { return 0, false } @@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { yield() // Race with another call. slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) } - errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { // Successfully added region; we can increment nextSlot and // allow another set to proceed here. diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 7e8e9f42a..9add7c944 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -23,6 +23,8 @@ import ( "sync/atomic" "syscall" "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) //go:linkname throw runtime.throw @@ -49,6 +51,13 @@ func uintptrValue(addr *byte) uintptr { return (uintptr)(unsafe.Pointer(addr)) } +// bluepillArchContext returns the UContext64. +// +//go:nosplit +func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { + return &((*arch.UContext64)(context).MContext) +} + // bluepillHandler is called from the signal stub. // // The world may be stopped while this is executing, and it executes on the @@ -80,13 +89,17 @@ func bluepillHandler(context unsafe.Pointer) { // interrupted KVM. Since we're in a signal handler // currently, all signals are masked and the signal // must have been delivered directly to this thread. + timeout := syscall.Timespec{} sig, _, errno := syscall.RawSyscall6( syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), - 0, // siginfo. - 0, // timeout. - 8, // sigset size. + 0, // siginfo. + uintptr(unsafe.Pointer(&timeout)), // timeout. + 8, // sigset size. 0, 0) + if errno == syscall.EAGAIN { + continue + } if errno != 0 { throw("error waiting for pending signal") } @@ -162,7 +175,7 @@ func bluepillHandler(context unsafe.Pointer) { // For MMIO, the physical address is the first data item. physical := uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical) + virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) if !ok { c.die(bluepillArchContext(context), "invalid physical address") return diff --git a/pkg/sentry/platform/kvm/filters.go b/pkg/sentry/platform/kvm/filters_amd64.go index 7d949f1dd..7d949f1dd 100644 --- a/pkg/sentry/platform/kvm/filters.go +++ b/pkg/sentry/platform/kvm/filters_amd64.go diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go new file mode 100644 index 000000000..9245d07c2 --- /dev/null +++ b/pkg/sentry/platform/kvm/filters_arm64.go @@ -0,0 +1,32 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +// SyscallFilters returns syscalls made exclusively by the KVM platform. +func (*KVM) SyscallFilters() seccomp.SyscallRules { + return seccomp.SyscallRules{ + syscall.SYS_IOCTL: {}, + syscall.SYS_MMAP: {}, + syscall.SYS_RT_SIGSUSPEND: {}, + syscall.SYS_RT_SIGTIMEDWAIT: {}, + 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host. + } +} diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ee4cd2f4d..f2c2c059e 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -21,7 +21,6 @@ import ( "sync" "syscall" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" @@ -56,9 +55,7 @@ func New(deviceFile *os.File) (*KVM, error) { // Ensure global initialization is done. globalOnce.Do(func() { - physicalInit() - globalErr = updateSystemValues(int(fd)) - ring0.Init(cpuid.HostFeatureSet()) + globalErr = updateGlobalOnce(int(fd)) }) if globalErr != nil { return nil, globalErr diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go index 5d8ef4761..c5a6f9c7d 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64.go +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -17,6 +17,7 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) @@ -211,3 +212,11 @@ type cpuidEntries struct { _ uint32 entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry } + +// updateGlobalOnce does global initialization. It has to be called only once. +func updateGlobalOnce(fd int) error { + physicalInit() + err := updateSystemValues(int(fd)) + ring0.Init(cpuid.HostFeatureSet()) + return err +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go new file mode 100644 index 000000000..2319c86d3 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -0,0 +1,83 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "syscall" +) + +// userMemoryRegion is a region of physical memory. +// +// This mirrors kvm_memory_region. +type userMemoryRegion struct { + slot uint32 + flags uint32 + guestPhysAddr uint64 + memorySize uint64 + userspaceAddr uint64 +} + +type kvmOneReg struct { + id uint64 + addr uint64 +} + +const KVM_NR_SPSR = 5 + +type userFpsimdState struct { + vregs [64]uint64 + fpsr uint32 + fpcr uint32 + reserved [2]uint32 +} + +type userRegs struct { + Regs syscall.PtraceRegs + sp_el1 uint64 + elr_el1 uint64 + spsr [KVM_NR_SPSR]uint64 + fpRegs userFpsimdState +} + +// runData is the run structure. This may be mapped for synchronous register +// access (although that doesn't appear to be supported by my kernel at least). +// +// This mirrors kvm_run. +type runData struct { + requestInterruptWindow uint8 + _ [7]uint8 + + exitReason uint32 + readyForInterruptInjection uint8 + ifFlag uint8 + _ [2]uint8 + + cr8 uint64 + apicBase uint64 + + // This is the union data for exits. Interpretation depends entirely on + // the exitReason above (see vCPU code for more information). + data [32]uint64 +} + +// updateGlobalOnce does global initialization. It has to be called only once. +func updateGlobalOnce(fd int) error { + physicalInit() + err := updateSystemValues(int(fd)) + updateVectorTable() + return err +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go new file mode 100644 index 000000000..6531bae1d --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go @@ -0,0 +1,39 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "fmt" + "syscall" +) + +var ( + runDataSize int +) + +func updateSystemValues(fd int) error { + // Extract the mmap size. + sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0) + if errno != 0 { + return fmt.Errorf("getting VCPU mmap size: %v", errno) + } + // Save the data. + runDataSize = int(sz) + + // Success. + return nil +} diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index d05f05c29..1d5c77ff4 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -49,11 +49,13 @@ const ( _KVM_EXIT_SHUTDOWN = 0x8 _KVM_EXIT_FAIL_ENTRY = 0x9 _KVM_EXIT_INTERNAL_ERROR = 0x11 + _KVM_EXIT_SYSTEM_EVENT = 0x18 ) // KVM capability options. const ( - _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 ) // KVM limits. @@ -62,3 +64,10 @@ const ( _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 ) + +// KVM kvm_memory_region::flags. +const ( + _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0 + _KVM_MEM_READONLY = uint32(1) << 1 + _KVM_MEM_FLAGS_NONE = 0 +) diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go new file mode 100644 index 000000000..5a74c6e36 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -0,0 +1,132 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// KVM ioctls for Arm64. +const ( + _KVM_GET_ONE_REG = 0x4010aeab + _KVM_SET_ONE_REG = 0x4010aeac + + _KVM_ARM_PREFERRED_TARGET = 0x8020aeaf + _KVM_ARM_VCPU_INIT = 0x4020aeae + _KVM_ARM64_REGS_PSTATE = 0x6030000000100042 + _KVM_ARM64_REGS_SP_EL1 = 0x6030000000100044 + _KVM_ARM64_REGS_R0 = 0x6030000000100000 + _KVM_ARM64_REGS_R1 = 0x6030000000100002 + _KVM_ARM64_REGS_R2 = 0x6030000000100004 + _KVM_ARM64_REGS_R3 = 0x6030000000100006 + _KVM_ARM64_REGS_R8 = 0x6030000000100010 + _KVM_ARM64_REGS_R18 = 0x6030000000100024 + _KVM_ARM64_REGS_PC = 0x6030000000100040 + _KVM_ARM64_REGS_MAIR_EL1 = 0x603000000013c510 + _KVM_ARM64_REGS_TCR_EL1 = 0x603000000013c102 + _KVM_ARM64_REGS_TTBR0_EL1 = 0x603000000013c100 + _KVM_ARM64_REGS_TTBR1_EL1 = 0x603000000013c101 + _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 + _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 + _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 +) + +// Arm64: Architectural Feature Access Control Register EL1. +const ( + _FPEN_NOTRAP = 0x3 + _FPEN_SHIFT = 0x20 +) + +// Arm64: System Control Register EL1. +const ( + _SCTLR_M = 1 << 0 + _SCTLR_C = 1 << 2 + _SCTLR_I = 1 << 12 +) + +// Arm64: Translation Control Register EL1. +const ( + _TCR_IPS_40BITS = 2 << 32 // PA=40 + _TCR_IPS_48BITS = 5 << 32 // PA=48 + + _TCR_T0SZ_OFFSET = 0 + _TCR_T1SZ_OFFSET = 16 + _TCR_IRGN0_SHIFT = 8 + _TCR_IRGN1_SHIFT = 24 + _TCR_ORGN0_SHIFT = 10 + _TCR_ORGN1_SHIFT = 26 + _TCR_SH0_SHIFT = 12 + _TCR_SH1_SHIFT = 28 + _TCR_TG0_SHIFT = 14 + _TCR_TG1_SHIFT = 30 + + _TCR_T0SZ_VA48 = 64 - 48 // VA=48 + _TCR_T1SZ_VA48 = 64 - 48 // VA=48 + + _TCR_ASID16 = 1 << 36 + _TCR_TBI0 = 1 << 37 + + _TCR_TXSZ_VA48 = (_TCR_T0SZ_VA48 << _TCR_T0SZ_OFFSET) | (_TCR_T1SZ_VA48 << _TCR_T1SZ_OFFSET) + + _TCR_TG0_4K = 0 << _TCR_TG0_SHIFT // 4K + _TCR_TG0_64K = 1 << _TCR_TG0_SHIFT // 64K + + _TCR_TG1_4K = 2 << _TCR_TG1_SHIFT + + _TCR_TG_FLAGS = _TCR_TG0_4K | _TCR_TG1_4K + + _TCR_IRGN0_WBWA = 1 << _TCR_IRGN0_SHIFT + _TCR_IRGN1_WBWA = 1 << _TCR_IRGN1_SHIFT + _TCR_IRGN_WBWA = _TCR_IRGN0_WBWA | _TCR_IRGN1_WBWA + + _TCR_ORGN0_WBWA = 1 << _TCR_ORGN0_SHIFT + _TCR_ORGN1_WBWA = 1 << _TCR_ORGN1_SHIFT + + _TCR_ORGN_WBWA = _TCR_ORGN0_WBWA | _TCR_ORGN1_WBWA + + _TCR_SHARED = (3 << _TCR_SH0_SHIFT) | (3 << _TCR_SH1_SHIFT) + + _TCR_CACHE_FLAGS = _TCR_IRGN_WBWA | _TCR_ORGN_WBWA +) + +// Arm64: Memory Attribute Indirection Register EL1. +const ( + _MT_DEVICE_nGnRnE = 0 + _MT_DEVICE_nGnRE = 1 + _MT_DEVICE_GRE = 2 + _MT_NORMAL_NC = 3 + _MT_NORMAL = 4 + _MT_NORMAL_WT = 5 + _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8) +) + +const ( + _KVM_ARM_VCPU_POWER_OFF = 0 // CPU is started in OFF state + _KVM_ARM_VCPU_PSCI_0_2 = 2 // CPU uses PSCI v0.2 +) + +// Arm64: Exception Syndrome Register EL1. +const ( + _ESR_ELx_FSC = 0x3F + + _ESR_SEGV_MAPERR_L0 = 0x4 + _ESR_SEGV_MAPERR_L1 = 0x5 + _ESR_SEGV_MAPERR_L2 = 0x6 + _ESR_SEGV_MAPERR_L3 = 0x7 + + _ESR_SEGV_ACCERR_L1 = 0x9 + _ESR_SEGV_ACCERR_L2 = 0xa + _ESR_SEGV_ACCERR_L3 = 0xb + + _ESR_SEGV_PEMERR_L1 = 0xd + _ESR_SEGV_PEMERR_L2 = 0xe + _ESR_SEGV_PEMERR_L3 = 0xf +) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index cc6c138b2..7d02ebf19 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) + var physicalRegionsReadOnly []physicalRegion + var physicalRegionsAvailable []physicalRegion + + physicalRegionsReadOnly = rdonlyRegionsForSetMem() + physicalRegionsAvailable = availableRegionsForSetMem() + + // Map all read-only regions. + for _, r := range physicalRegionsReadOnly { + m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) + } + // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to @@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) { if excludeVirtualRegion(vr) { return // skip region. } + + for _, r := range physicalRegionsReadOnly { + if vr.virtual == r.virtual { + return + } + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length) + m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) virtual += length } }) @@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. -func (m *machine) mapPhysical(physical, length uintptr) { +func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { - _, physicalStart, length, ok := calculateBluepillFault(physical) + _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { // Should never happen. panic("mapPhysical on unknown physical address") @@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) { if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { // Not present in the cache; requires setting the slot. - if _, ok := handleBluepillFault(m, physical); !ok { + if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c1cbe33be..b99fe425e 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) { } } } + +// On x86 platform, the flags for "setMemoryRegion" can always be set as 0. +// There is no need to return read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + return nil +} + +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 506ec9af1..7156c245f 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -26,30 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - // loadSegments copies the current segments. // // This may be called from within the signal context and throws on error. @@ -159,3 +135,43 @@ func (c *vCPU) setSignalMask() error { } return nil } + +// setUserRegisters sets user registers in the vCPU. +func (c *vCPU) setUserRegisters(uregs *userRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) + } + return nil +} + +// getUserRegisters reloads user registers in the vCPU. +// +// This is safe to call from a nosplit context. +// +//go:nosplit +func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return errno + } + return 0 +} + +// setSystemRegisters sets system registers. +func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SREGS, + uintptr(unsafe.Pointer(sregs))); errno != 0 { + return fmt.Errorf("error setting system registers: %v", errno) + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go new file mode 100644 index 000000000..7ae47f291 --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -0,0 +1,183 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +type vCPUArchState struct { + // PCIDs is the set of PCIDs for this vCPU. + // + // This starts above fixedKernelPCID. + PCIDs *pagetables.PCIDs +} + +const ( + // fixedKernelPCID is a fixed kernel PCID used for the kernel page + // tables. We must start allocating user PCIDs above this in order to + // avoid any conflict (see below). + fixedKernelPCID = 1 + + // poolPCIDs is the number of PCIDs to record in the database. As this + // grows, assignment can take longer, since it is a simple linear scan. + // Beyond a relatively small number, there are likely few perform + // benefits, since the TLB has likely long since lost any translations + // from more than a few PCIDs past. + poolPCIDs = 8 +) + +// Get all read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + var rdonlyRegions []region + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return + } + + if !vr.accessType.Write && vr.accessType.Read { + rdonlyRegions = append(rdonlyRegions, vr.region) + } + }) + + for _, r := range rdonlyRegions { + physical, _, ok := translateToPhysical(r.virtual) + if !ok { + continue + } + + phyRegions = append(phyRegions, physicalRegion{ + region: region{ + virtual: r.virtual, + length: r.length, + }, + physical: physical, + }) + } + + return phyRegions +} + +// Get all available physicalRegions. +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + var excludeRegions []region + applyVirtualRegions(func(vr virtualRegion) { + if !vr.accessType.Write { + excludeRegions = append(excludeRegions, vr.region) + } + }) + + phyRegions = computePhysicalRegions(excludeRegions) + + return phyRegions +} + +// dropPageTables drops cached page table entries. +func (m *machine) dropPageTables(pt *pagetables.PageTables) { + m.mu.Lock() + defer m.mu.Unlock() + + // Clear from all PCIDs. + for _, c := range m.vCPUs { + c.PCIDs.Drop(pt) + } +} + +// nonCanonical generates a canonical address return. +// +//go:nosplit +func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + *info = arch.SignalInfo{ + Signo: signal, + Code: arch.SignalInfoKernel, + } + info.SetAddr(addr) // Include address. + return usermem.NoAccess, platform.ErrContextSignal +} + +// fault generates an appropriate fault return. +// +//go:nosplit +func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + faultAddr := c.GetFaultAddr() + code, user := c.ErrorCode() + + // Reset the pointed SignalInfo. + *info = arch.SignalInfo{Signo: signal} + info.SetAddr(uint64(faultAddr)) + + read := true + write := false + execute := true + + ret := code & _ESR_ELx_FSC + switch ret { + case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3: + info.Code = 1 //SEGV_MAPERR + read = false + write = true + execute = false + case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3: + info.Code = 2 // SEGV_ACCERR. + read = true + write = false + execute = false + default: + info.Code = 2 + } + + if !user { + read = true + write = false + execute = true + + } + accessType := usermem.AccessType{ + Read: read, + Write: write, + Execute: execute, + } + + return accessType, platform.ErrContextSignal +} + +// retryInGuest runs the given function in guest mode. +// +// If the function does not complete in guest mode (due to execution of a +// system call due to a GC stall, for example), then it will be retried. The +// given function must be idempotent as a result of the retry mechanism. +func (m *machine) retryInGuest(fn func()) { + c := m.Get() + defer m.Put(c) + for { + c.ClearErrorCode() // See below. + bluepill(c) // Force guest mode. + fn() // Execute the given function. + _, user := c.ErrorCode() + if user { + // If user is set, then we haven't bailed back to host + // mode via a kernel exception or system call. We + // consider the full function to have executed in guest + // mode and we can return. + break + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go new file mode 100644 index 000000000..3f2f97a6b --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "fmt" + "reflect" + "sync/atomic" + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: 0, + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + +type kvmVcpuInit struct { + target uint32 + features [7]uint32 +} + +var vcpuInit kvmVcpuInit + +// initArchState initializes architecture-specific state. +func (m *machine) initArchState() error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_ARM_PREFERRED_TARGET, + uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { + panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno)) + } + return nil +} + +func getPageWithReflect(p uintptr) []byte { + return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()] +} + +// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. +// +// According to the design documentation of Arm64, +// the start address of exception vector table should be 11-bits aligned. +// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S +// But, we can't align a function's start address to a specific address by using golang. +// We have raised this question in golang community: +// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I +// This function will be removed when golang supports this feature. +// +// There are 2 jobs were implemented in this function: +// 1, move the start address of exception vector table into the specific address. +// 2, modify the offset of each instruction. +func updateVectorTable() { + fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() + offset := fromLocation & (1<<11 - 1) + if offset != 0 { + offset = 1<<11 - offset + } + + toLocation := fromLocation + offset + page := getPageWithReflect(toLocation) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil { + panic(err) + } + + page = getPageWithReflect(toLocation + 4096) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil { + panic(err) + } + + // Move exception-vector-table into the specific address. + var entry *uint32 + var entryFrom *uint32 + for i := 1; i <= 0x800; i++ { + entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i))) + entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i))) + *entry = *entryFrom + } + + // The offset from the address of each unconditionally branch is changed. + // We should modify the offset of each instruction. + nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780} + for _, num := range nums { + entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num))) + *entry = *entry - (uint32)(offset/4) + } + + page = getPageWithReflect(toLocation) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil { + panic(err) + } + + page = getPageWithReflect(toLocation + 4096) + if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil { + panic(err) + } +} + +// initArchState initializes architecture-specific state. +func (c *vCPU) initArchState() error { + var ( + reg kvmOneReg + data uint64 + regGet kvmOneReg + dataGet uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer()) + + vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_ARM_VCPU_INIT, + uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { + panic(fmt.Sprintf("error setting KVM_ARM_VCPU_INIT failed: %v", errno)) + } + + // cpacr_el1 + reg.id = _KVM_ARM64_REGS_CPACR_EL1 + data = (_FPEN_NOTRAP << _FPEN_SHIFT) + if err := c.setOneRegister(®); err != nil { + return err + } + + // sctlr_el1 + regGet.id = _KVM_ARM64_REGS_SCTLR_EL1 + if err := c.getOneRegister(®Get); err != nil { + return err + } + + dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I) + data = dataGet + reg.id = _KVM_ARM64_REGS_SCTLR_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // tcr_el1 + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS + reg.id = _KVM_ARM64_REGS_TCR_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // mair_el1 + data = _MT_EL1_INIT + reg.id = _KVM_ARM64_REGS_MAIR_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // ttbr0_el1 + data = c.machine.kernel.PageTables.TTBR0_EL1(false, 0) + + reg.id = _KVM_ARM64_REGS_TTBR0_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + c.SetTtbr0Kvm(uintptr(data)) + + // ttbr1_el1 + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) + + reg.id = _KVM_ARM64_REGS_TTBR1_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // sp_el1 + data = c.CPU.StackTop() + reg.id = _KVM_ARM64_REGS_SP_EL1 + if err := c.setOneRegister(®); err != nil { + return err + } + + // pc + reg.id = _KVM_ARM64_REGS_PC + data = uint64(reflect.ValueOf(ring0.Start).Pointer()) + if err := c.setOneRegister(®); err != nil { + return err + } + + // r8 + reg.id = _KVM_ARM64_REGS_R8 + data = uint64(reflect.ValueOf(&c.CPU).Pointer()) + if err := c.setOneRegister(®); err != nil { + return err + } + + // vbar_el1 + reg.id = _KVM_ARM64_REGS_VBAR_EL1 + + fromLocation := reflect.ValueOf(ring0.Vectors).Pointer() + offset := fromLocation & (1<<11 - 1) + if offset != 0 { + offset = 1<<11 - offset + } + + toLocation := fromLocation + offset + data = uint64(ring0.KernelStartAddress | toLocation) + if err := c.setOneRegister(®); err != nil { + return err + } + + data = ring0.PsrDefaultSet | ring0.KernelFlagsSet + reg.id = _KVM_ARM64_REGS_PSTATE + if err := c.setOneRegister(®); err != nil { + return err + } + + return nil +} + +//go:nosplit +func (c *vCPU) loadSegments(tid uint64) { + // TODO(gvisor.dev/issue/1238): TLS is not supported. + // Get TLS from tpidr_el0. + atomic.StoreUint64(&c.tid, tid) +} + +func (c *vCPU) setOneRegister(reg *kvmOneReg) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_ONE_REG, + uintptr(unsafe.Pointer(reg))); errno != 0 { + return fmt.Errorf("error setting one register: %v", errno) + } + return nil +} + +func (c *vCPU) getOneRegister(reg *kvmOneReg) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_ONE_REG, + uintptr(unsafe.Pointer(reg))); errno != 0 { + return fmt.Errorf("error setting one register: %v", errno) + } + return nil +} + +// setCPUID sets the CPUID to be used by the guest. +func (c *vCPU) setCPUID() error { + return nil +} + +// setSystemTime sets the TSC for the vCPU. +func (c *vCPU) setSystemTime() error { + return nil +} + +// setSignalMask sets the vCPU signal mask. +// +// This must be called prior to running the vCPU. +func (c *vCPU) setSignalMask() error { + // The layout of this structure implies that it will not necessarily be + // the same layout chosen by the Go compiler. It gets fudged here. + var data struct { + length uint32 + mask1 uint32 + mask2 uint32 + _ uint32 + } + data.length = 8 // Fixed sigset size. + data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) + data.mask2 = ^uint32(bounceSignalMask >> 32) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SIGNAL_MASK, + uintptr(unsafe.Pointer(&data))); errno != 0 { + return fmt.Errorf("error setting signal mask: %v", errno) + } + + return nil +} + +// SwitchToUser unpacks architectural-details. +func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) { + // Check for canonical addresses. + if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { + return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info) + } else if !ring0.IsCanonical(regs.Sp) { + return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) + } + + var vector ring0.Vector + ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0) + c.SetTtbr0App(uintptr(ttbr0App)) + + // TODO(gvisor.dev/issue/1238): full context-switch supporting for Arm64. + // The Arm64 user-mode execution state consists of: + // x0-x30 + // PC, SP, PSTATE + // V0-V31: 32 128-bit registers for floating point, and simd + // FPSR + // TPIDR_EL0, used for TLS + appRegs := switchOpts.Registers + c.SetAppAddr(ring0.KernelStartAddress | uintptr(unsafe.Pointer(appRegs))) + + entersyscall() + bluepill(c) + vector = c.CPU.SwitchToUser(switchOpts) + exitsyscall() + + switch vector { + case ring0.Syscall: + // Fast path: system call executed. + return usermem.NoAccess, nil + + case ring0.PageFault: + return c.fault(int32(syscall.SIGSEGV), info) + case 0xaa: + return usermem.NoAccess, nil + default: + return usermem.NoAccess, platform.ErrContextSignal + } + +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 405e00292..f04be2ab5 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -35,6 +35,30 @@ func entersyscall() //go:linkname exitsyscall runtime.exitsyscall func exitsyscall() +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: uint32(flags), + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + // mapRunData maps the vCPU run data. func mapRunData(fd int) (*runData, error) { r, _, errno := syscall.RawSyscall6( @@ -63,46 +87,6 @@ func unmapRunData(r *runData) error { return nil } -// setUserRegisters sets user registers in the vCPU. -func (c *vCPU) setUserRegisters(uregs *userRegs) error { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_REGS, - uintptr(unsafe.Pointer(uregs))); errno != 0 { - return fmt.Errorf("error setting user registers: %v", errno) - } - return nil -} - -// getUserRegisters reloads user registers in the vCPU. -// -// This is safe to call from a nosplit context. -// -//go:nosplit -func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_GET_REGS, - uintptr(unsafe.Pointer(uregs))); errno != 0 { - return errno - } - return 0 -} - -// setSystemRegisters sets system registers. -func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_SREGS, - uintptr(unsafe.Pointer(sregs))); errno != 0 { - return fmt.Errorf("error setting system registers: %v", errno) - } - return nil -} - // atomicAddressSpace is an atomic address space pointer. type atomicAddressSpace struct { pointer unsafe.Pointer diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index 77a449a8b..b0e45f159 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -9,6 +9,8 @@ go_library( "testutil.go", "testutil_amd64.go", "testutil_amd64.s", + "testutil_arm64.go", + "testutil_arm64.s", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", visibility = ["//pkg/sentry/platform/kvm:__pkg__"], diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go index 6cf2359a3..5c1efa0fd 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil.go +++ b/pkg/sentry/platform/kvm/testutil/testutil.go @@ -41,9 +41,6 @@ func TwiddleRegsFault() // TwiddleRegsSyscall twiddles registers then executes a syscall. func TwiddleRegsSyscall() -// TwiddleSegments reads segments into known registers. -func TwiddleSegments() - // FloatingPointWorks is a floating point test. // // It returns true or false. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go index 203d71528..4c108abbf 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go @@ -21,6 +21,9 @@ import ( "syscall" ) +// TwiddleSegments reads segments into known registers. +func TwiddleSegments() + // SetTestTarget sets the rip appropriately. func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go new file mode 100644 index 000000000..40b2e4acc --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -0,0 +1,59 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package testutil + +import ( + "fmt" + "reflect" + "syscall" +) + +// SetTestTarget sets the rip appropriately. +func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { + regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +} + +// SetTouchTarget sets rax appropriately. +func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { + if target != nil { + regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer()) + } else { + regs.Regs[8] = 0 + } +} + +// RewindSyscall rewinds a syscall RIP. +func RewindSyscall(regs *syscall.PtraceRegs) { + regs.Pc -= 4 +} + +// SetTestRegs initializes registers to known values. +func SetTestRegs(regs *syscall.PtraceRegs) { + for i := 0; i <= 30; i++ { + regs.Regs[i] = uint64(i) + 1 + } +} + +// CheckTestRegs checks that registers were twiddled per TwiddleRegs. +func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { + for i := 0; i <= 30; i++ { + if need := ^uint64(i + 1); regs.Regs[i] != need { + err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) + } + } + return +} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s new file mode 100644 index 000000000..0bebee852 --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -0,0 +1,106 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +// test_util_arm64.s provides ARM64 test functions. + +#include "funcdata.h" +#include "textflag.h" + +#define SYS_GETPID 172 + +// This function simulates the getpid syscall. +TEXT ·Getpid(SB),NOSPLIT,$0 + NO_LOCAL_POINTERS + MOVD $SYS_GETPID, R8 + SVC + RET + +TEXT ·Touch(SB),NOSPLIT,$0 +start: + MOVD 0(R8), R1 + MOVD $SYS_GETPID, R8 // getpid + SVC + B start + +TEXT ·HaltLoop(SB),NOSPLIT,$0 +start: + HLT + B start + +// This function simulates a loop of syscall. +TEXT ·SyscallLoop(SB),NOSPLIT,$0 +start: + SVC + B start + +TEXT ·SpinLoop(SB),NOSPLIT,$0 +start: + B start + +TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8 + NO_LOCAL_POINTERS + FMOVD $(9.9), F0 + MOVD $SYS_GETPID, R8 // getpid + SVC + FMOVD $(9.9), F1 + FCMPD F0, F1 + BNE isNaN + MOVD $1, R0 + MOVD R0, ret+0(FP) + RET +isNaN: + MOVD $0, ret+0(FP) + RET + +// MVN: bitwise logical NOT +// This case simulates an application that modified R0-R30. +#define TWIDDLE_REGS() \ + MVN R0, R0; \ + MVN R1, R1; \ + MVN R2, R2; \ + MVN R3, R3; \ + MVN R4, R4; \ + MVN R5, R5; \ + MVN R6, R6; \ + MVN R7, R7; \ + MVN R8, R8; \ + MVN R9, R9; \ + MVN R10, R10; \ + MVN R11, R11; \ + MVN R12, R12; \ + MVN R13, R13; \ + MVN R14, R14; \ + MVN R15, R15; \ + MVN R16, R16; \ + MVN R17, R17; \ + MVN R18_PLATFORM, R18_PLATFORM; \ + MVN R19, R19; \ + MVN R20, R20; \ + MVN R21, R21; \ + MVN R22, R22; \ + MVN R23, R23; \ + MVN R24, R24; \ + MVN R25, R25; \ + MVN R26, R26; \ + MVN R27, R27; \ + MVN g, g; \ + MVN R29, R29; \ + MVN R30, R30; + +TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 + TWIDDLE_REGS() + SVC + RET // never reached diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index ebcc8c098..0df8cfa0f 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/procid", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/hostcpu", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/safecopy", diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 47957bb3b..72c7ec564 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) { cpu: ^uint32(0), }, nil } + +// getEventMessage retrieves a message about the ptrace event that just happened. +func (t *thread) getEventMessage() (uintptr, error) { + var msg uintptr + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETEVENTMSG, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(&msg)), + 0, 0) + if errno != 0 { + return msg, errno + } + return msg, nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 6bf7cd097..ddb1f41e3 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread { // attach attaches to the thread. func (t *thread) attach() { - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("unable to attach: %v", errno)) } @@ -327,6 +327,20 @@ func (t *thread) dumpAndPanic(message string) { panic(message) } +func (t *thread) unexpectedStubExit() { + msg, err := t.getEventMessage() + status := syscall.WaitStatus(msg) + if status.Signaled() && status.Signal() == syscall.SIGKILL { + // SIGKILL can be only sent by an user or OOM-killer. In both + // these cases, we don't need to panic. There is no reasons to + // think that something wrong in gVisor. + log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid) + pid := os.Getpid() + syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL)) + } + t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) +} + // wait waits for a stop event. // // Precondition: outcome is a valid waitOutcome. @@ -355,7 +369,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { } if stopSig == syscall.SIGTRAP { if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { - t.dumpAndPanic("wait failed: the process exited") + t.unexpectedStubExit() } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) @@ -416,7 +430,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { for { // Execute the syscall instruction. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } @@ -426,12 +440,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { break } else { // Some other signal caused a thread stop; ignore. + if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD { + log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig) + } continue } } // Complete the actual system call. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } @@ -522,17 +539,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { for { // Start running until the next system call. if isSingleStepping(regs) { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU_SINGLESTEP, - uintptr(t.tid), 0); errno != 0 { + uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) } } else { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall6( syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, - uintptr(t.tid), 0); errno != 0 { + uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace sysemu failed: %v", errno)) } } diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 4649a94a7..606dc2b1d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -21,6 +21,8 @@ import ( "strings" "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" ) @@ -143,3 +145,49 @@ func (t *thread) adjustInitRegsRip() { func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { initregs.R15 = uint64(ppid) } + +// patchSignalInfo patches the signal info to account for hitting the seccomp +// filters from vsyscall emulation, specified below. We allow for SIGSYS as a +// synchronous trap, but patch the structure to appear like a SIGSEGV with the +// Rip as the faulting address. +// +// Note that this should only be called after verifying that the signalInfo has +// been generated by the kernel. +func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { + if linux.Signal(signalInfo.Signo) == linux.SIGSYS { + signalInfo.Signo = int32(linux.SIGSEGV) + + // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered + // with the si_call_addr field pointing to the current RIP. This field + // aligns with the si_addr field for a SIGSEGV, so we don't need to touch + // anything there. We do need to unwind emulation however, so we set the + // instruction pointer to the faulting value, and "unpop" the stack. + regs.Rip = signalInfo.Addr() + regs.Rsp -= 8 + } +} + +// enableCpuidFault enables cpuid-faulting. +// +// This may fail on older kernels or hardware, so we just disregard the result. +// Host CPUID will be enabled. +// +// This is safe to call in an afterFork context. +// +//go:nosplit +func enableCpuidFault() { + syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0) +} + +// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. +// Ref attachedThread() for more detail. +func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { + return append(rules, seccomp.RuleSet{ + Rules: seccomp.SyscallRules{ + syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ + {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }) +} diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index bec884ba5..62a686ee7 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -17,8 +17,12 @@ package ptrace import ( + "fmt" + "strings" "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" ) @@ -37,7 +41,7 @@ const ( // resetSysemuRegs sets up emulation registers. // // This should be called prior to calling sysemu. -func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) { +func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) { } // createSyscallRegs sets up syscall registers. @@ -124,3 +128,36 @@ func (t *thread) adjustInitRegsRip() { func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) { initregs.Regs[7] = uint64(ppid) } + +// patchSignalInfo patches the signal info to account for hitting the seccomp +// filters from vsyscall emulation, specified below. We allow for SIGSYS as a +// synchronous trap, but patch the structure to appear like a SIGSEGV with the +// Rip as the faulting address. +// +// Note that this should only be called after verifying that the signalInfo has +// been generated by the kernel. +func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { + if linux.Signal(signalInfo.Signo) == linux.SIGSYS { + signalInfo.Signo = int32(linux.SIGSEGV) + + // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered + // with the si_call_addr field pointing to the current RIP. This field + // aligns with the si_addr field for a SIGSEGV, so we don't need to touch + // anything there. We do need to unwind emulation however, so we set the + // instruction pointer to the faulting value, and "unpop" the stack. + regs.Pc = signalInfo.Addr() + regs.Sp -= 8 + } +} + +// Noop on arm64. +// +//go:nosplit +func enableCpuidFault() { +} + +// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. +// Ref attachedThread() for more detail. +func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { + return rules +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index f09b0b3d0..cf13ea5e4 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -20,6 +20,7 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" @@ -53,7 +54,7 @@ func probeSeccomp() bool { for { // Attempt an emulation. - if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 { + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) } @@ -77,27 +78,6 @@ func probeSeccomp() bool { } } -// patchSignalInfo patches the signal info to account for hitting the seccomp -// filters from vsyscall emulation, specified below. We allow for SIGSYS as a -// synchronous trap, but patch the structure to appear like a SIGSEGV with the -// Rip as the faulting address. -// -// Note that this should only be called after verifying that the signalInfo has -// been generated by the kernel. -func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) { - if linux.Signal(signalInfo.Signo) == linux.SIGSYS { - signalInfo.Signo = int32(linux.SIGSEGV) - - // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered - // with the si_call_addr field pointing to the current RIP. This field - // aligns with the si_addr field for a SIGSEGV, so we don't need to touch - // anything there. We do need to unwind emulation however, so we set the - // instruction pointer to the faulting value, and "unpop" the stack. - regs.Rip = signalInfo.Addr() - regs.Rsp -= 8 - } -} - // createStub creates a fresh stub processes. // // Precondition: the runtime OS thread must be locked. @@ -129,6 +109,9 @@ func createStub() (*thread, error) { // transitively) will be killed as well. It's simply not possible to // safely handle a single stub getting killed: the exact state of // execution is unknown and not recoverable. + // + // In addition, we set the PTRACE_O_TRACEEXIT option to log more + // information about a stub process when it receives a fatal signal. return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction) } @@ -146,7 +129,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro Rules: seccomp.SyscallRules{ syscall.SYS_GETTIMEOFDAY: {}, syscall.SYS_TIME: {}, - 309: {}, // SYS_GETCPU. + unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64. }, Action: linux.SECCOMP_RET_TRAP, Vsyscall: true, @@ -170,10 +153,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // For the initial process creation. syscall.SYS_WAIT4: {}, - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, - }, - syscall.SYS_EXIT: {}, + syscall.SYS_EXIT: {}, // For the stub prctl dance (all). syscall.SYS_PRCTL: []seccomp.Rule{ @@ -193,6 +173,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }, Action: linux.SECCOMP_RET_ALLOW, }) + + rules = appendArchSeccompRules(rules) } instrs, err := seccomp.BuildProgram(rules, defaultAction) if err != nil { @@ -264,9 +246,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0) } - // Enable cpuid-faulting; this may fail on older kernels or hardware, - // so we just disregard the result. Host CPUID will be enabled. - syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0) + // Enable cpuid-faulting. + enableCpuidFault() // Call the stub; should not return. stubCall(stubStart, ppid) diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index de6783fb0..2e6fbe488 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -25,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/hostcpu" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, @@ -49,20 +50,6 @@ func unmaskAllSignals() syscall.Errno { return errno } -// getCPU gets the current CPU. -// -// Precondition: the current runtime thread should be locked. -func getCPU() (uint32, error) { - var cpu uintptr - if _, _, errno := syscall.RawSyscall( - unix.SYS_GETCPU, - uintptr(unsafe.Pointer(&cpu)), - 0, 0); errno != 0 { - return 0, errno - } - return uint32(cpu), nil -} - // setCPU sets the CPU affinity. func (t *thread) setCPU(cpu uint32) error { mask := maskPool.Get().([]uintptr) @@ -93,10 +80,8 @@ func (t *thread) setCPU(cpu uint32) error { // // Precondition: the current runtime thread should be locked. func (t *thread) bind() { - currentCPU, err := getCPU() - if err != nil { - return - } + currentCPU := hostcpu.GetCPU() + if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { // Set the affinity on the thread and save the CPU for next // round; we don't expect CPUs to bounce around too frequently. diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index b80a3604d..2ae6b9f9d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 8ed6c7652..87f4552b5 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,11 +1,10 @@ load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - go_template( - name = "defs", + name = "defs_amd64", srcs = [ "defs.go", "defs_amd64.go", @@ -15,11 +14,29 @@ go_template( visibility = [":__subpackages__"], ) +go_template( + name = "defs_arm64", + srcs = [ + "aarch64.go", + "defs.go", + "defs_arm64.go", + "offsets_arm64.go", + ], + visibility = [":__subpackages__"], +) + +go_template_instance( + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", + package = "ring0", + template = ":defs_amd64", +) + go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", package = "ring0", - template = ":defs", + template = ":defs_arm64", ) genrule( @@ -30,17 +47,31 @@ genrule( tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) +genrule( + name = "entry_impl_arm64", + srcs = ["entry_arm64.s"], + outs = ["entry_impl_arm64.s"], + cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) + go_library( name = "ring0", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "entry_amd64.go", + "entry_arm64.go", "entry_impl_amd64.s", + "entry_impl_arm64.s", "kernel.go", "kernel_amd64.go", + "kernel_arm64.go", "kernel_unsafe.go", "lib_amd64.go", "lib_amd64.s", + "lib_arm64.go", + "lib_arm64.s", "ring0.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go new file mode 100644 index 000000000..6b078cd1e --- /dev/null +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -0,0 +1,109 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// Useful bits. +const ( + _PGD_PGT_BASE = 0x1000 + _PGD_PGT_SIZE = 0x1000 + _PUD_PGT_BASE = 0x2000 + _PUD_PGT_SIZE = 0x1000 + _PMD_PGT_BASE = 0x3000 + _PMD_PGT_SIZE = 0x4000 + _PTE_PGT_BASE = 0x7000 + _PTE_PGT_SIZE = 0x1000 + + _PSR_MODE_EL0t = 0x0 + _PSR_MODE_EL1t = 0x4 + _PSR_MODE_EL1h = 0x5 + _PSR_EL_MASK = 0xf + + _PSR_D_BIT = 0x200 + _PSR_A_BIT = 0x100 + _PSR_I_BIT = 0x80 + _PSR_F_BIT = 0x40 +) + +const ( + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _PSR_MODE_EL1h + + // UserFlagsSet are always set in userspace. + UserFlagsSet = _PSR_MODE_EL0t + + KernelFlagsClear = _PSR_EL_MASK + UserFlagsClear = _PSR_EL_MASK + + PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT +) + +// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + El1SyncInvalid = iota + El1IrqInvalid + El1FiqInvalid + El1ErrorInvalid + El1Sync + El1Irq + El1Fiq + El1Error + El0Sync + El0Irq + El0Fiq + El0Error + El0Sync_invalid + El0Irq_invalid + El0Fiq_invalid + El0Error_invalid + El1Sync_da + El1Sync_ia + El1Sync_sp_pc + El1Sync_undef + El1Sync_dbg + El1Sync_inv + El0Sync_svc + El0Sync_da + El0Sync_ia + El0Sync_fpsimd_acc + El0Sync_sve_acc + El0Sync_sys + El0Sync_sp_pc + El0Sync_undef + El0Sync_dbg + El0Sync_inv + VirtualizationException + _NR_INTERRUPTS +) + +// System call vectors. +const ( + Syscall Vector = El0Sync_svc + PageFault Vector = El0Sync_da +) + +// VirtualAddressBits returns the number bits available for virtual addresses. +func VirtualAddressBits() uint32 { + return 48 +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +func PhysicalAddressBits() uint32 { + return 40 +} diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 076063f85..3f094c2a7 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -20,17 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - // Kernel is a global kernel object. // // This contains global state, shared by multiple CPUs. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 7206322b1..10dbd381f 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -20,6 +20,17 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + // Segment indices and Selectors. const ( // Index into GDT array. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go new file mode 100644 index 000000000..dc0eeec01 --- /dev/null +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -0,0 +1,136 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits()) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + +// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [512]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of error code here, it is always set + // along with the errorCode value above. + // + // It will either by 1, which indicates a user error, or 0 indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // faultAddr is the value of far_el1. + faultAddr uintptr + + // ttbr0Kvm is the value of ttbr0_el1 for sentry. + ttbr0Kvm uintptr + + // ttbr0App is the value of ttbr0_el1 for applicaton. + ttbr0App uintptr + + // exception vector. + vecCode Vector + + // application context pointer. + appAddr uintptr + + // lazyVFP is the value of cpacr_el1. + lazyVFP uintptr +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 // No code. + c.errorType = 1 // User mode. +} + +//go:nosplit +func (c *CPU) GetFaultAddr() (value uintptr) { + return c.faultAddr +} + +//go:nosplit +func (c *CPU) SetTtbr0Kvm(value uintptr) { + c.ttbr0Kvm = value +} + +//go:nosplit +func (c *CPU) SetTtbr0App(value uintptr) { + c.ttbr0App = value +} + +//go:nosplit +func (c *CPU) GetVector() (value Vector) { + return c.vecCode +} + +//go:nosplit +func (c *CPU) SetAppAddr(value uintptr) { + c.appAddr = value +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserASID indicates that the application ASID to be used on switch, + UserASID uint16 + + // KernelASID indicates that the kernel ASID to be used on return, + KernelASID uint16 +} + +func init() { +} diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go new file mode 100644 index 000000000..0dfa42c36 --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_arm64.go @@ -0,0 +1,60 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// This is an assembly function. +// +// The sysenter function is invoked in two situations: +// +// (1) The guest kernel has executed a system call. +// (2) The guest application has executed a system call. +// +// The interrupt flag is examined to determine whether the system call was +// executed from kernel mode or not and the appropriate stub is called. + +func El1_sync_invalid() +func El1_irq_invalid() +func El1_fiq_invalid() +func El1_error_invalid() + +func El1_sync() +func El1_irq() +func El1_fiq() +func El1_error() + +func El0_sync() +func El0_irq() +func El0_fiq() +func El0_error() + +func El0_sync_invalid() +func El0_irq_invalid() +func El0_fiq_invalid() +func El0_error_invalid() + +func Vectors() + +// Start is the CPU entrypoint. +// +// The CPU state will be set to c.Registers(). +func Start() +func kernelExitToEl1() + +func kernelExitToEl0() + +// Shutdown execution +func Shutdown() diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s new file mode 100644 index 000000000..add2c3e08 --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -0,0 +1,591 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +// NB: Offsets are programatically generated (see BUILD). +// +// This file is concatenated with the definitions. + +// Saves a register set. +// +// This is a macro because it may need to executed in contents where a stack is +// not available for calls. +// + +#define ERET() \ + WORD $0xd69f03e0 + +#define RSV_REG R18_PLATFORM +#define RSV_REG_APP R9 + +#define FPEN_NOTRAP 0x3 +#define FPEN_SHIFT 20 + +#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) + +#define REGISTERS_SAVE(reg, offset) \ + MOVD R0, offset+PTRACE_R0(reg); \ + MOVD R1, offset+PTRACE_R1(reg); \ + MOVD R2, offset+PTRACE_R2(reg); \ + MOVD R3, offset+PTRACE_R3(reg); \ + MOVD R4, offset+PTRACE_R4(reg); \ + MOVD R5, offset+PTRACE_R5(reg); \ + MOVD R6, offset+PTRACE_R6(reg); \ + MOVD R7, offset+PTRACE_R7(reg); \ + MOVD R8, offset+PTRACE_R8(reg); \ + MOVD R10, offset+PTRACE_R10(reg); \ + MOVD R11, offset+PTRACE_R11(reg); \ + MOVD R12, offset+PTRACE_R12(reg); \ + MOVD R13, offset+PTRACE_R13(reg); \ + MOVD R14, offset+PTRACE_R14(reg); \ + MOVD R15, offset+PTRACE_R15(reg); \ + MOVD R16, offset+PTRACE_R16(reg); \ + MOVD R17, offset+PTRACE_R17(reg); \ + MOVD R19, offset+PTRACE_R19(reg); \ + MOVD R20, offset+PTRACE_R20(reg); \ + MOVD R21, offset+PTRACE_R21(reg); \ + MOVD R22, offset+PTRACE_R22(reg); \ + MOVD R23, offset+PTRACE_R23(reg); \ + MOVD R24, offset+PTRACE_R24(reg); \ + MOVD R25, offset+PTRACE_R25(reg); \ + MOVD R26, offset+PTRACE_R26(reg); \ + MOVD R27, offset+PTRACE_R27(reg); \ + MOVD g, offset+PTRACE_R28(reg); \ + MOVD R29, offset+PTRACE_R29(reg); \ + MOVD R30, offset+PTRACE_R30(reg); + +#define REGISTERS_LOAD(reg, offset) \ + MOVD offset+PTRACE_R0(reg), R0; \ + MOVD offset+PTRACE_R1(reg), R1; \ + MOVD offset+PTRACE_R2(reg), R2; \ + MOVD offset+PTRACE_R3(reg), R3; \ + MOVD offset+PTRACE_R4(reg), R4; \ + MOVD offset+PTRACE_R5(reg), R5; \ + MOVD offset+PTRACE_R6(reg), R6; \ + MOVD offset+PTRACE_R7(reg), R7; \ + MOVD offset+PTRACE_R8(reg), R8; \ + MOVD offset+PTRACE_R10(reg), R10; \ + MOVD offset+PTRACE_R11(reg), R11; \ + MOVD offset+PTRACE_R12(reg), R12; \ + MOVD offset+PTRACE_R13(reg), R13; \ + MOVD offset+PTRACE_R14(reg), R14; \ + MOVD offset+PTRACE_R15(reg), R15; \ + MOVD offset+PTRACE_R16(reg), R16; \ + MOVD offset+PTRACE_R17(reg), R17; \ + MOVD offset+PTRACE_R19(reg), R19; \ + MOVD offset+PTRACE_R20(reg), R20; \ + MOVD offset+PTRACE_R21(reg), R21; \ + MOVD offset+PTRACE_R22(reg), R22; \ + MOVD offset+PTRACE_R23(reg), R23; \ + MOVD offset+PTRACE_R24(reg), R24; \ + MOVD offset+PTRACE_R25(reg), R25; \ + MOVD offset+PTRACE_R26(reg), R26; \ + MOVD offset+PTRACE_R27(reg), R27; \ + MOVD offset+PTRACE_R28(reg), g; \ + MOVD offset+PTRACE_R29(reg), R29; \ + MOVD offset+PTRACE_R30(reg), R30; + +//NOP +#define nop31Instructions() \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; + +#define ESR_ELx_EC_UNKNOWN (0x00) +#define ESR_ELx_EC_WFx (0x01) +/* Unallocated EC: 0x02 */ +#define ESR_ELx_EC_CP15_32 (0x03) +#define ESR_ELx_EC_CP15_64 (0x04) +#define ESR_ELx_EC_CP14_MR (0x05) +#define ESR_ELx_EC_CP14_LS (0x06) +#define ESR_ELx_EC_FP_ASIMD (0x07) +#define ESR_ELx_EC_CP10_ID (0x08) /* EL2 only */ +#define ESR_ELx_EC_PAC (0x09) /* EL2 and above */ +/* Unallocated EC: 0x0A - 0x0B */ +#define ESR_ELx_EC_CP14_64 (0x0C) +/* Unallocated EC: 0x0d */ +#define ESR_ELx_EC_ILL (0x0E) +/* Unallocated EC: 0x0F - 0x10 */ +#define ESR_ELx_EC_SVC32 (0x11) +#define ESR_ELx_EC_HVC32 (0x12) /* EL2 only */ +#define ESR_ELx_EC_SMC32 (0x13) /* EL2 and above */ +/* Unallocated EC: 0x14 */ +#define ESR_ELx_EC_SVC64 (0x15) +#define ESR_ELx_EC_HVC64 (0x16) /* EL2 and above */ +#define ESR_ELx_EC_SMC64 (0x17) /* EL2 and above */ +#define ESR_ELx_EC_SYS64 (0x18) +#define ESR_ELx_EC_SVE (0x19) +/* Unallocated EC: 0x1A - 0x1E */ +#define ESR_ELx_EC_IMP_DEF (0x1f) /* EL3 only */ +#define ESR_ELx_EC_IABT_LOW (0x20) +#define ESR_ELx_EC_IABT_CUR (0x21) +#define ESR_ELx_EC_PC_ALIGN (0x22) +/* Unallocated EC: 0x23 */ +#define ESR_ELx_EC_DABT_LOW (0x24) +#define ESR_ELx_EC_DABT_CUR (0x25) +#define ESR_ELx_EC_SP_ALIGN (0x26) +/* Unallocated EC: 0x27 */ +#define ESR_ELx_EC_FP_EXC32 (0x28) +/* Unallocated EC: 0x29 - 0x2B */ +#define ESR_ELx_EC_FP_EXC64 (0x2C) +/* Unallocated EC: 0x2D - 0x2E */ +#define ESR_ELx_EC_SERROR (0x2F) +#define ESR_ELx_EC_BREAKPT_LOW (0x30) +#define ESR_ELx_EC_BREAKPT_CUR (0x31) +#define ESR_ELx_EC_SOFTSTP_LOW (0x32) +#define ESR_ELx_EC_SOFTSTP_CUR (0x33) +#define ESR_ELx_EC_WATCHPT_LOW (0x34) +#define ESR_ELx_EC_WATCHPT_CUR (0x35) +/* Unallocated EC: 0x36 - 0x37 */ +#define ESR_ELx_EC_BKPT32 (0x38) +/* Unallocated EC: 0x39 */ +#define ESR_ELx_EC_VECTOR32 (0x3A) /* EL2 only */ +/* Unallocted EC: 0x3B */ +#define ESR_ELx_EC_BRK64 (0x3C) +/* Unallocated EC: 0x3D - 0x3F */ +#define ESR_ELx_EC_MAX (0x3F) + +#define ESR_ELx_EC_SHIFT (26) +#define ESR_ELx_EC_MASK (UL(0x3F) << ESR_ELx_EC_SHIFT) +#define ESR_ELx_EC(esr) (((esr) & ESR_ELx_EC_MASK) >> ESR_ELx_EC_SHIFT) + +#define ESR_ELx_IL_SHIFT (25) +#define ESR_ELx_IL (UL(1) << ESR_ELx_IL_SHIFT) +#define ESR_ELx_ISS_MASK (ESR_ELx_IL - 1) + +/* ISS field definitions shared by different classes */ +#define ESR_ELx_WNR_SHIFT (6) +#define ESR_ELx_WNR (UL(1) << ESR_ELx_WNR_SHIFT) + +/* Asynchronous Error Type */ +#define ESR_ELx_IDS_SHIFT (24) +#define ESR_ELx_IDS (UL(1) << ESR_ELx_IDS_SHIFT) +#define ESR_ELx_AET_SHIFT (10) +#define ESR_ELx_AET (UL(0x7) << ESR_ELx_AET_SHIFT) + +#define ESR_ELx_AET_UC (UL(0) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UEU (UL(1) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UEO (UL(2) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UER (UL(3) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_CE (UL(6) << ESR_ELx_AET_SHIFT) + +/* Shared ISS field definitions for Data/Instruction aborts */ +#define ESR_ELx_SET_SHIFT (11) +#define ESR_ELx_SET_MASK (UL(3) << ESR_ELx_SET_SHIFT) +#define ESR_ELx_FnV_SHIFT (10) +#define ESR_ELx_FnV (UL(1) << ESR_ELx_FnV_SHIFT) +#define ESR_ELx_EA_SHIFT (9) +#define ESR_ELx_EA (UL(1) << ESR_ELx_EA_SHIFT) +#define ESR_ELx_S1PTW_SHIFT (7) +#define ESR_ELx_S1PTW (UL(1) << ESR_ELx_S1PTW_SHIFT) + +/* Shared ISS fault status code(IFSC/DFSC) for Data/Instruction aborts */ +#define ESR_ELx_FSC (0x3F) +#define ESR_ELx_FSC_TYPE (0x3C) +#define ESR_ELx_FSC_EXTABT (0x10) +#define ESR_ELx_FSC_SERROR (0x11) +#define ESR_ELx_FSC_ACCESS (0x08) +#define ESR_ELx_FSC_FAULT (0x04) +#define ESR_ELx_FSC_PERM (0x0C) + +/* ISS field definitions for Data Aborts */ +#define ESR_ELx_ISV_SHIFT (24) +#define ESR_ELx_ISV (UL(1) << ESR_ELx_ISV_SHIFT) +#define ESR_ELx_SAS_SHIFT (22) +#define ESR_ELx_SAS (UL(3) << ESR_ELx_SAS_SHIFT) +#define ESR_ELx_SSE_SHIFT (21) +#define ESR_ELx_SSE (UL(1) << ESR_ELx_SSE_SHIFT) +#define ESR_ELx_SRT_SHIFT (16) +#define ESR_ELx_SRT_MASK (UL(0x1F) << ESR_ELx_SRT_SHIFT) +#define ESR_ELx_SF_SHIFT (15) +#define ESR_ELx_SF (UL(1) << ESR_ELx_SF_SHIFT) +#define ESR_ELx_AR_SHIFT (14) +#define ESR_ELx_AR (UL(1) << ESR_ELx_AR_SHIFT) +#define ESR_ELx_CM_SHIFT (8) +#define ESR_ELx_CM (UL(1) << ESR_ELx_CM_SHIFT) + +/* ISS field definitions for exceptions taken in to Hyp */ +#define ESR_ELx_CV (UL(1) << 24) +#define ESR_ELx_COND_SHIFT (20) +#define ESR_ELx_COND_MASK (UL(0xF) << ESR_ELx_COND_SHIFT) +#define ESR_ELx_WFx_ISS_TI (UL(1) << 0) +#define ESR_ELx_WFx_ISS_WFI (UL(0) << 0) +#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) +#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) + +#define LOAD_KERNEL_ADDRESS(from, to) \ + MOVD from, to; \ + ORR $0xffff000000000000, to, to; + +// LOAD_KERNEL_STACK loads the kernel temporary stack. +#define LOAD_KERNEL_STACK(from) \ + LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \ + MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \ + MOVD RSV_REG, RSP; \ + ISB $15; \ + DSB $15; + +#define SWITCH_TO_APP_PAGETABLE(from) \ + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + ISB $15; \ + DSB $15; + +#define SWITCH_TO_KVM_PAGETABLE(from) \ + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + ISB $15; \ + DSB $15; + +#define IRQ_ENABLE \ + MSR $2, DAIFSet; + +#define IRQ_DISABLE \ + MSR $2, DAIFClr; + +#define VFP_ENABLE \ + MOVD $FPEN_ENABLE, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define VFP_DISABLE \ + MOVD $0x0, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define KERNEL_ENTRY_FROM_EL0 \ + SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. + STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable. + SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + MOVD RSV_REG_APP, R20; \ + LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ + ADD $16, RSP, RSP; \ + MOVD RSV_REG, PTRACE_R18(R20); \ + MOVD RSV_REG_APP, PTRACE_R9(R20); \ + MOVD R20, RSV_REG_APP; \ + WORD $0xd5384003; \ // MRS SPSR_EL1, R3 + MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MRS ELR_EL1, R3; \ + MOVD R3, PTRACE_PC(RSV_REG_APP); \ + WORD $0xd5384103; \ // MRS SP_EL0, R3 + MOVD R3, PTRACE_SP(RSV_REG_APP); + +#define KERNEL_ENTRY_FROM_EL1 \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context + MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \ + WORD $0xd5384004; \ // MRS SPSR_EL1, R4 + MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \ + MRS ELR_EL1, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \ + MOVD RSP, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); + +TEXT ·Halt(SB),NOSPLIT,$0 + // clear bluepill. + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + CMP RSV_REG, R9 + BNE mmio_exit + MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG) +mmio_exit: + // Disable fpsimd. + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, CPU_LAZY_VFP(RSV_REG) + VFP_DISABLE + + // MMIO_EXIT. + MOVD $0, R9 + MOVD R0, 0xffff000000001000(R9) + B ·kernelExitToEl1(SB) + +TEXT ·Shutdown(SB),NOSPLIT,$0 + // PSCI EVENT. + MOVD $0x84000009, R0 + HVC $0 + +// See kernel.go. +TEXT ·Current(SB),NOSPLIT,$0-8 + MOVD CPU_SELF(RSV_REG), R8 + MOVD R8, ret+0(FP) + RET + +#define STACK_FRAME_SIZE 16 + +TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 + ERET() + +TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 + ERET() + +TEXT ·Start(SB),NOSPLIT,$0 + IRQ_DISABLE + MOVD R8, RSV_REG + ORR $0xffff000000000000, RSV_REG, RSV_REG + WORD $0xd518d092 //MSR R18, TPIDR_EL1 + + B ·kernelExitToEl1(SB) + +TEXT ·El1_sync_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_irq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_error_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_sync(SB),NOSPLIT,$0 + KERNEL_ENTRY_FROM_EL1 + WORD $0xd5385219 // MRS ESR_EL1, R25 + LSR $ESR_ELx_EC_SHIFT, R25, R24 + CMP $ESR_ELx_EC_DABT_CUR, R24 + BEQ el1_da + CMP $ESR_ELx_EC_IABT_CUR, R24 + BEQ el1_ia + CMP $ESR_ELx_EC_SYS64, R24 + BEQ el1_undef + CMP $ESR_ELx_EC_SP_ALIGN, R24 + BEQ el1_sp_pc + CMP $ESR_ELx_EC_PC_ALIGN, R24 + BEQ el1_sp_pc + CMP $ESR_ELx_EC_UNKNOWN, R24 + BEQ el1_undef + CMP $ESR_ELx_EC_SVC64, R24 + BEQ el1_svc + CMP $ESR_ELx_EC_BREAKPT_CUR, R24 + BGE el1_dbg + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el1_fpsimd_acc + B el1_invalid + +el1_da: + B ·Halt(SB) + +el1_ia: + B ·Halt(SB) + +el1_sp_pc: + B ·Shutdown(SB) + +el1_undef: + B ·Shutdown(SB) + +el1_svc: + B ·Halt(SB) + +el1_dbg: + B ·Shutdown(SB) + +el1_fpsimd_acc: + VFP_ENABLE + B ·kernelExitToEl1(SB) // Resume. + +el1_invalid: + B ·Shutdown(SB) + +TEXT ·El1_irq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_fiq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_error(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_sync(SB),NOSPLIT,$0 + KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + LSR $ESR_ELx_EC_SHIFT, R25, R24 + CMP $ESR_ELx_EC_SVC64, R24 + BEQ el0_svc + CMP $ESR_ELx_EC_DABT_LOW, R24 + BEQ el0_da + CMP $ESR_ELx_EC_IABT_LOW, R24 + BEQ el0_ia + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el0_fpsimd_acc + CMP $ESR_ELx_EC_SVE, R24 + BEQ el0_sve_acc + CMP $ESR_ELx_EC_FP_EXC64, R24 + BEQ el0_fpsimd_exc + CMP $ESR_ELx_EC_SP_ALIGN, R24 + BEQ el0_sp_pc + CMP $ESR_ELx_EC_PC_ALIGN, R24 + BEQ el0_sp_pc + CMP $ESR_ELx_EC_UNKNOWN, R24 + BEQ el0_undef + CMP $ESR_ELx_EC_BREAKPT_LOW, R24 + BGE el0_dbg + B el0_invalid + +el0_svc: + B ·Halt(SB) + +el0_da: + B ·Halt(SB) + +el0_ia: + B ·Shutdown(SB) + +el0_fpsimd_acc: + B ·Shutdown(SB) + +el0_sve_acc: + B ·Shutdown(SB) + +el0_fpsimd_exc: + B ·Shutdown(SB) + +el0_sp_pc: + B ·Shutdown(SB) + +el0_undef: + B ·Shutdown(SB) + +el0_dbg: + B ·Shutdown(SB) + +el0_invalid: + B ·Shutdown(SB) + +TEXT ·El0_irq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_fiq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_error(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_irq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_error_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·Vectors(SB),NOSPLIT,$0 + B ·El1_sync_invalid(SB) + nop31Instructions() + B ·El1_irq_invalid(SB) + nop31Instructions() + B ·El1_fiq_invalid(SB) + nop31Instructions() + B ·El1_error_invalid(SB) + nop31Instructions() + + B ·El1_sync(SB) + nop31Instructions() + B ·El1_irq(SB) + nop31Instructions() + B ·El1_fiq(SB) + nop31Instructions() + B ·El1_error(SB) + nop31Instructions() + + B ·El0_sync(SB) + nop31Instructions() + B ·El0_irq(SB) + nop31Instructions() + B ·El0_fiq(SB) + nop31Instructions() + B ·El0_error(SB) + nop31Instructions() + + B ·El0_sync_invalid(SB) + nop31Instructions() + B ·El0_irq_invalid(SB) + nop31Instructions() + B ·El0_fiq_invalid(SB) + nop31Instructions() + B ·El0_error_invalid(SB) + nop31Instructions() + + WORD $0xd503201f //nop + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index d7029d5a9..42076fb04 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -1,20 +1,27 @@ load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") +go_template_instance( + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", + package = "main", + template = "//pkg/sentry/platform/ring0:defs_arm64", +) go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs", + template = "//pkg/sentry/platform/ring0:defs_amd64", ) go_binary( name = "gen_offsets", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "main.go", ], visibility = ["//pkg/sentry/platform/ring0:__pkg__"], diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go new file mode 100644 index 000000000..ed82a131e --- /dev/null +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -0,0 +1,58 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// init initializes architecture-specific state. +func (k *Kernel) init(opts KernelOpts) { + // Save the root page tables. + k.PageTables = opts.PageTables +} + +// init initializes architecture-specific state. +func (c *CPU) init() { + // Set the kernel stack pointer(virtual address). + c.registers.Sp = uint64(c.StackTop()) + +} + +// StackTop returns the kernel's stack address. +// +//go:nosplit +func (c *CPU) StackTop() uint64 { + return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack)) +} + +// IsCanonical indicates whether addr is canonical per the arm64 spec. +// +//go:nosplit +func IsCanonical(addr uint64) bool { + return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 +} + +//go:nosplit +func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { + // Sanitize registers. + regs := switchOpts.Registers + + regs.Pstate &= ^uint64(UserFlagsClear) + regs.Pstate |= UserFlagsSet + kernelExitToEl0() + vector = c.vecCode + + // Perform the switch. + return +} diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go new file mode 100644 index 000000000..8bcfe1032 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -0,0 +1,39 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// CPACREL1 returns the value of the CPACR_EL1 register. +func CPACREL1() (value uintptr) + +// FPCR returns the value of FPCR register. +func FPCR() (value uintptr) + +// SetFPCR writes the FPCR value. +func SetFPCR(value uintptr) + +// FPSR returns the value of FPSR register. +func FPSR() (value uintptr) + +// SetFPSR writes the FPSR value. +func SetFPSR(value uintptr) + +// SaveVRegs saves V0-V31 registers. +// V0-V31: 32 128-bit registers for floating point and simd. +func SaveVRegs(*byte) + +// LoadVRegs loads V0-V31 registers. +func LoadVRegs(*byte) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s new file mode 100644 index 000000000..1c9171004 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -0,0 +1,118 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +TEXT ·CPACREL1(SB),NOSPLIT,$0-8 + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·FPCR(SB),NOSPLIT,$0-8 + WORD $0xd53b4201 // MRS NZCV, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·GetFPSR(SB),NOSPLIT,$0-8 + WORD $0xd53b4421 // MRS FPSR, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·FPCR(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + WORD $0xd51b4201 // MSR R1, NZCV + RET + +TEXT ·SetFPSR(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + WORD $0xd51b4421 // MSR R1, FPSR + RET + +TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD F0, 16*1(R0) + FMOVD F1, 16*2(R0) + FMOVD F2, 16*3(R0) + FMOVD F3, 16*4(R0) + FMOVD F4, 16*5(R0) + FMOVD F5, 16*6(R0) + FMOVD F6, 16*7(R0) + FMOVD F7, 16*8(R0) + FMOVD F8, 16*9(R0) + FMOVD F9, 16*10(R0) + FMOVD F10, 16*11(R0) + FMOVD F11, 16*12(R0) + FMOVD F12, 16*13(R0) + FMOVD F13, 16*14(R0) + FMOVD F14, 16*15(R0) + FMOVD F15, 16*16(R0) + FMOVD F16, 16*17(R0) + FMOVD F17, 16*18(R0) + FMOVD F18, 16*19(R0) + FMOVD F19, 16*20(R0) + FMOVD F20, 16*21(R0) + FMOVD F21, 16*22(R0) + FMOVD F22, 16*23(R0) + FMOVD F23, 16*24(R0) + FMOVD F24, 16*25(R0) + FMOVD F25, 16*26(R0) + FMOVD F26, 16*27(R0) + FMOVD F27, 16*28(R0) + FMOVD F28, 16*29(R0) + FMOVD F29, 16*30(R0) + FMOVD F30, 16*31(R0) + FMOVD F31, 16*32(R0) + ISB $15 + + RET + +TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD 16*1(R0), F0 + FMOVD 16*2(R0), F1 + FMOVD 16*3(R0), F2 + FMOVD 16*4(R0), F3 + FMOVD 16*5(R0), F4 + FMOVD 16*6(R0), F5 + FMOVD 16*7(R0), F6 + FMOVD 16*8(R0), F7 + FMOVD 16*9(R0), F8 + FMOVD 16*10(R0), F9 + FMOVD 16*11(R0), F10 + FMOVD 16*12(R0), F11 + FMOVD 16*13(R0), F12 + FMOVD 16*14(R0), F13 + FMOVD 16*15(R0), F14 + FMOVD 16*16(R0), F15 + FMOVD 16*17(R0), F16 + FMOVD 16*18(R0), F17 + FMOVD 16*19(R0), F18 + FMOVD 16*20(R0), F19 + FMOVD 16*21(R0), F20 + FMOVD 16*22(R0), F21 + FMOVD 16*23(R0), F22 + FMOVD 16*24(R0), F23 + FMOVD 16*25(R0), F24 + FMOVD 16*26(R0), F25 + FMOVD 16*27(R0), F26 + FMOVD 16*28(R0), F27 + FMOVD 16*29(R0), F28 + FMOVD 16*30(R0), F29 + FMOVD 16*31(R0), F30 + FMOVD 16*32(R0), F31 + ISB $15 + + RET diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go new file mode 100644 index 000000000..cd2a65f97 --- /dev/null +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -0,0 +1,125 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "fmt" + "io" + "reflect" + "syscall" +) + +// Emit prints architecture-specific offsets. +func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_FAULT_ADDR 0x%02x\n", reflect.ValueOf(&c.faultAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_TTBR0_KVM 0x%02x\n", reflect.ValueOf(&c.ttbr0Kvm).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) + fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) + fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) + fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) + + fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) + fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) + fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) + fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + + fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) + fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) + fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) + fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + + fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) + fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) + fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) + fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + + fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) + fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) + fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) + fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) + fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) + fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + + fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) + fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) + fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) + fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) + fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) + fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) + fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) + fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) + fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg) + fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &syscall.PtraceRegs{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R2 0x%02x\n", reflect.ValueOf(&p.Regs[2]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R3 0x%02x\n", reflect.ValueOf(&p.Regs[3]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R4 0x%02x\n", reflect.ValueOf(&p.Regs[4]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R5 0x%02x\n", reflect.ValueOf(&p.Regs[5]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R6 0x%02x\n", reflect.ValueOf(&p.Regs[6]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R7 0x%02x\n", reflect.ValueOf(&p.Regs[7]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.Regs[8]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.Regs[9]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.Regs[10]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.Regs[11]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.Regs[12]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.Regs[13]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.Regs[14]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.Regs[15]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R16 0x%02x\n", reflect.ValueOf(&p.Regs[16]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R17 0x%02x\n", reflect.ValueOf(&p.Regs[17]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R18 0x%02x\n", reflect.ValueOf(&p.Regs[18]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R19 0x%02x\n", reflect.ValueOf(&p.Regs[19]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R20 0x%02x\n", reflect.ValueOf(&p.Regs[20]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R21 0x%02x\n", reflect.ValueOf(&p.Regs[21]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R22 0x%02x\n", reflect.ValueOf(&p.Regs[22]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R23 0x%02x\n", reflect.ValueOf(&p.Regs[23]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R24 0x%02x\n", reflect.ValueOf(&p.Regs[24]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R25 0x%02x\n", reflect.ValueOf(&p.Regs[25]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R26 0x%02x\n", reflect.ValueOf(&p.Regs[26]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R27 0x%02x\n", reflect.ValueOf(&p.Regs[27]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R28 0x%02x\n", reflect.ValueOf(&p.Regs[28]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R29 0x%02x\n", reflect.ValueOf(&p.Regs[29]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R30 0x%02x\n", reflect.ValueOf(&p.Regs[30]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer()) +} diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 3b95af617..e2e15ba5c 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,14 +1,17 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") +config_setting( + name = "aarch64", + constraint_values = ["@bazel_tools//platforms:aarch64"], +) go_template( name = "generic_walker", - srcs = [ - "walker_amd64.go", - ], + srcs = ["walker_amd64.go"], opt_types = [ "Visitor", ], @@ -76,9 +79,13 @@ go_library( "allocator.go", "allocator_unsafe.go", "pagetables.go", + "pagetables_aarch64.go", "pagetables_amd64.go", + "pagetables_arm64.go", "pagetables_x86.go", "pcids_x86.go", + "walker_amd64.go", + "walker_arm64.go", "walker_empty.go", "walker_lookup.go", "walker_map.go", @@ -97,6 +104,7 @@ go_test( size = "small", srcs = [ "pagetables_amd64_test.go", + "pagetables_arm64_test.go", "pagetables_test.go", "walker_check.go", ], diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 904f1a6de..30c64a372 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -48,15 +48,6 @@ func New(a Allocator) *PageTables { return p } -// Init initializes a set of PageTables. -// -//go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) -} - // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go new file mode 100644 index 000000000..e78424766 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -0,0 +1,212 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// archPageTables is architecture-specific data. +type archPageTables struct { + // root is the pagetable root for kernel space. + root *PTEs + + // rootPhysical is the cached physical address of the root. + // + // This is saved only to prevent constant translation. + rootPhysical uintptr + + asid uint16 +} + +// TTBR0_EL1 returns the translation table base register 0. +// +//go:nosplit +func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { + return uint64(p.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset +} + +// TTBR1_EL1 returns the translation table base register 1. +// +//go:nosplit +func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { + return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset +} + +// Bits in page table entries. +const ( + typeTable = 0x3 << 0 + typeSect = 0x1 << 0 + typePage = 0x3 << 0 + pteValid = 0x1 << 0 + pteTableBit = 0x1 << 1 + pteTypeMask = 0x3 << 0 + present = pteValid | pteTableBit + user = 0x1 << 6 /* AP[1] */ + readOnly = 0x1 << 7 /* AP[2] */ + accessed = 0x1 << 10 + dbm = 0x1 << 51 + writable = dbm + cont = 0x1 << 52 + pxn = 0x1 << 53 + xn = 0x1 << 54 + dirty = 0x1 << 55 + nG = 0x1 << 11 + shared = 0x3 << 8 +) + +const ( + mtNormal = 0x4 << 2 +) + +const ( + executeDisable = xn + optionMask = 0xfff | 0xfff<<48 + protDefault = accessed | shared | mtNormal +) + +// MapOpts are x86 options. +type MapOpts struct { + // AccessType defines permissions. + AccessType usermem.AccessType + + // Global indicates the page is globally accessible. + Global bool + + // User indicates the page is a user page. + User bool +} + +// PTE is a page table entry. +type PTE uintptr + +// Clear clears this PTE, including sect page information. +// +//go:nosplit +func (p *PTE) Clear() { + atomic.StoreUintptr((*uintptr)(p), 0) +} + +// Valid returns true iff this entry is valid. +// +//go:nosplit +func (p *PTE) Valid() bool { + return atomic.LoadUintptr((*uintptr)(p))&present != 0 +} + +// Opts returns the PTE options. +// +// These are all options except Valid and Sect. +// +//go:nosplit +func (p *PTE) Opts() MapOpts { + v := atomic.LoadUintptr((*uintptr)(p)) + + return MapOpts{ + AccessType: usermem.AccessType{ + Read: true, + Write: v&readOnly == 0, + Execute: v&xn == 0, + }, + Global: v&nG == 0, + User: v&user != 0, + } +} + +// SetSect sets this page as a sect page. +// +// The page must not be valid or a panic will result. +// +//go:nosplit +func (p *PTE) SetSect() { + if p.Valid() { + // This is not allowed. + panic("SetSect called on valid page!") + } + atomic.StoreUintptr((*uintptr)(p), typeSect) +} + +// IsSect returns true iff this page is a sect page. +// +//go:nosplit +func (p *PTE) IsSect() bool { + return atomic.LoadUintptr((*uintptr)(p))&pteTypeMask == typeSect +} + +// Set sets this PTE value. +// +// This does not change the sect page property. +// +//go:nosplit +func (p *PTE) Set(addr uintptr, opts MapOpts) { + if !opts.AccessType.Any() { + p.Clear() + return + } + v := (addr &^ optionMask) | protDefault | nG | readOnly + + if p.IsSect() { + // Note that this is inherited from the previous instance. Set + // does not change the value of Sect. See above. + v |= typeSect + } else { + v |= typePage + } + + if opts.Global { + v = v &^ nG + } + + if opts.AccessType.Execute { + v = v &^ executeDisable + } else { + v |= executeDisable + } + if opts.AccessType.Write { + v = v &^ readOnly + } + + if opts.User { + v |= user + } else { + v = v &^ user + } + atomic.StoreUintptr((*uintptr)(p), v) +} + +// setPageTable sets this PTE value and forces the write bit and sect bit to +// be cleared. This is used explicitly for breaking sect pages. +// +//go:nosplit +func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) { + addr := pt.Allocator.PhysicalFor(ptes) + if addr&^optionMask != addr { + // This should never happen. + panic("unaligned physical address!") + } + v := addr | typeTable | protDefault + atomic.StoreUintptr((*uintptr)(p), v) +} + +// Address extracts the address. This should only be used if Valid returns true. +// +//go:nosplit +func (p *PTE) Address() uintptr { + return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 7aa6c524e..0c153cf8c 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -41,5 +41,14 @@ const ( entriesPerPage = 512 ) +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) +} + // PTEs is a collection of entries. type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go new file mode 100644 index 000000000..1a49f12a2 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -0,0 +1,57 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +// Address constraints. +// +// The lowerTop and upperBottom currently apply to four-level pagetables; +// additional refactoring would be necessary to support five-level pagetables. +const ( + lowerTop = 0x0000ffffffffffff + upperBottom = 0xffff000000000000 + pteShift = 12 + pmdShift = 21 + pudShift = 30 + pgdShift = 39 + + pteMask = 0x1ff << pteShift + pmdMask = 0x1ff << pmdShift + pudMask = 0x1ff << pudShift + pgdMask = 0x1ff << pgdShift + + pteSize = 1 << pteShift + pmdSize = 1 << pmdShift + pudSize = 1 << pudShift + pgdSize = 1 << pgdShift + + ttbrASIDOffset = 55 + ttbrASIDMask = 0xff + + entriesPerPage = 512 +) + +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) + p.archPageTables.root = p.Allocator.NewPTEs() + p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +} + +// PTEs is a collection of entries. +type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go new file mode 100644 index 000000000..254116233 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go @@ -0,0 +1,80 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func Test2MAnd4K(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a small page and a huge page. + pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47) + + pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42) + pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, + {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}}, + {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}}, + }) +} + +func Test1GAnd4K(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a small page and a super page. + pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, + {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} + +func TestSplit1GPage(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a super page and knock out the middle. + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42) + pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} + +func TestSplit2MPage(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a huge page and knock out the middle. + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42) + pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go new file mode 100644 index 000000000..c261d393a --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +// Visitor is a generic type. +type Visitor interface { + // visit is called on each PTE. + visit(start uintptr, pte *PTE, align uintptr) + + // requiresAlloc indicates that new entries should be allocated within + // the walked range. + requiresAlloc() bool + + // requiresSplit indicates that entries in the given range should be + // split if they are huge or jumbo pages. + requiresSplit() bool +} + +// Walker walks page tables. +type Walker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor Visitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is sect pages. If a valid sect page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of sect pages whenever +// possible. Whether a sect page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *Walker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func next(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *Walker) iterateRangeCanonical(start, end uintptr) { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + start = next(start, pgdSize) + continue + } + + // Allocate a new pgd. + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + // Map the next level. + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + clearPUDEntries++ + start = next(start, pudSize) + continue + } + + // This level has 1-GB sect pages. Is this + // entire region at least as large as a single + // PUD entry? If so, we can skip allocating a + // new page for the pmd. + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = next(start, pudSize) + continue + } + } + + // Allocate a new pud. + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + // Does this page need to be split? + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) { + // Install the relevant entries. + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + // A sect page to be checked directly. + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + // Might have been cleared. + if !pudEntry.Valid() { + clearPUDEntries++ + } + + // Note that the sect page was changed. + start = next(start, pudSize) + continue + } + + } else { + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + } + + // Map the next level, since this is valid. + clearPMDEntries := uint16(0) + + for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ { + var ( + pmdEntry = &pmdEntries[pmdIndex] + pteEntries *PTEs + ) + if !pmdEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + clearPMDEntries++ + start = next(start, pmdSize) + continue + } + + // This level has 2-MB huge pages. If this + // region is contined in a single PMD entry? + // As above, we can skip allocating a new page. + if start&(pmdSize-1) == 0 && end-start >= pmdSize { + pmdEntry.SetSect() + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if pmdEntry.Valid() { + start = next(start, pmdSize) + continue + } + } + + // Allocate a new pmd. + pteEntries = w.pageTables.Allocator.NewPTEs() + pmdEntry.setPageTable(w.pageTables, pteEntries) + + } else if pmdEntry.IsSect() { + // Does this page need to be split? + if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) { + // Install the relevant entries. + pteEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pteEntries[index].Set( + pmdEntry.Address()+(pteSize*uintptr(index)), + pmdEntry.Opts()) + } + pmdEntry.setPageTable(w.pageTables, pteEntries) + } else { + // A huge page to be checked directly. + w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + + // Might have been cleared. + if !pmdEntry.Valid() { + clearPMDEntries++ + } + + // Note that the huge page was changed. + start = next(start, pmdSize) + continue + } + + } else { + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + } + + // Map the next level, since this is valid. + clearPTEEntries := uint16(0) + + for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ { + var ( + pteEntry = &pteEntries[pteIndex] + ) + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { + clearPTEEntries++ + start += pteSize + continue + } + + // At this point, we are guaranteed that start%pteSize == 0. + w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !pteEntry.Valid() { + if w.visitor.requiresAlloc() { + panic("PTE not set after iteration with requiresAlloc!") + } + clearPTEEntries++ + } + + // Note that the pte was changed. + start += pteSize + continue + } + + // Check if we no longer need this page. + if clearPTEEntries == entriesPerPage { + pmdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pteEntries) + clearPMDEntries++ + } + } + + // Check if we no longer need this page. + if clearPMDEntries == entriesPerPage { + pudEntry.Clear() + w.pageTables.Allocator.FreePTEs(pmdEntries) + clearPUDEntries++ + } + } + + // Check if we no longer need this page. + if clearPUDEntries == entriesPerPage { + pgdEntry.Clear() + w.pageTables.Allocator.FreePTEs(pudEntries) + } + } +} diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 924d8a6d6..6769cd0a5 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index fd6dc8e6e..884020f7b 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go index 2f65db70b..ba1f9043d 100644 --- a/pkg/sentry/sighandling/sighandling.go +++ b/pkg/sentry/sighandling/sighandling.go @@ -16,7 +16,6 @@ package sighandling import ( - "fmt" "os" "os/signal" "reflect" @@ -31,37 +30,25 @@ const numSignals = 32 // handleSignals listens for incoming signals and calls the given handler // function. // -// It starts when the start channel is closed, stops when the stop channel -// is closed, and closes done once it will no longer deliver signals to k. -func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) { +// It stops when the stop channel is closed. The done channel is closed once it +// will no longer deliver signals to k. +func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) { // Build a select case. - sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}} + sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}} for _, sigchan := range sigchans { sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)}) } - started := false for { // Wait for a notification. index, _, ok := reflect.Select(sc) - // Was it the start / stop channel? + // Was it the stop channel? if index == 0 { if !ok { - if !started { - // start channel; start forwarding and - // swap this case for the stop channel - // to select stop requests. - started = true - sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)} - } else { - // stop channel; stop forwarding and - // clear this case so it is never - // selected again. - started = false - close(done) - sc[0].Chan = reflect.Value{} - } + // Stop forwarding and notify that it's done. + close(done) + return } continue } @@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, // Otherwise, it was a signal on channel N. Index 0 represents the stop // channel, so index N represents the channel for signal N. - signal := linux.Signal(index) - - if !started { - // Kernel cannot receive signals, either because it is - // not ready yet or is shutting down. - // - // Kill ourselves if this signal would have killed the - // process before PrepareForwarding was called. i.e., all - // _SigKill signals; see Go - // src/runtime/sigtab_linux_generic.go. - // - // Otherwise ignore the signal. - // - // TODO(b/114489875): Drop in Go 1.12, which uses tgkill - // in runtime.raise. - switch signal { - case linux.SIGHUP, linux.SIGINT, linux.SIGTERM: - dieFromSignal(signal) - panic(fmt.Sprintf("Failed to die from signal %d", signal)) - default: - continue - } - } - - // Pass the signal to the handler. - handler(signal) + handler(linux.Signal(index)) } } -// PrepareHandler ensures that synchronous signals are passed to the given -// handler function and returns a callback that starts signal delivery, which -// itself returns a callback that stops signal handling. +// StartSignalForwarding ensures that synchronous signals are passed to the +// given handler function and returns a callback that stops signal delivery. // // Note that this function permanently takes over signal handling. After the // stop callback, signals revert to the default Go runtime behavior, which // cannot be overridden with external calls to signal.Notify. -func PrepareHandler(handler func(linux.Signal)) func() func() { - start := make(chan struct{}) +func StartSignalForwarding(handler func(linux.Signal)) func() { stop := make(chan struct{}) done := make(chan struct{}) @@ -128,13 +88,10 @@ func PrepareHandler(handler func(linux.Signal)) func() func() { signal.Notify(sigchan, syscall.Signal(sig)) } // Start up our listener. - go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. + go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. - return func() func() { - close(start) - return func() { - close(stop) - <-done - } + return func() { + close(stop) + <-done } } diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index eace3766d..1ebe22d34 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -15,15 +15,13 @@ package sighandling import ( - "fmt" - "runtime" "syscall" "unsafe" "gvisor.dev/gvisor/pkg/abi/linux" ) -// TODO(b/34161764): Move to pkg/abi/linux along with definitions in +// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in // pkg/sentry/arch. type sigaction struct { handler uintptr @@ -48,27 +46,3 @@ func IgnoreChildStop() error { return nil } - -// dieFromSignal kills the current process with sig. -// -// Preconditions: The default action of sig is termination. -func dieFromSignal(sig linux.Signal) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - sa := sigaction{handler: linux.SIG_DFL} - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigaction failed: %v", e)) - } - - set := linux.MakeSignalSet(sig) - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigprocmask failed: %v", e)) - } - - if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil { - panic(fmt.Sprintf("tgkill failed: %v", err)) - } - - panic("failed to die") -} diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 3300f9a6b..26176b10d 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "socket", srcs = ["socket.go"], diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 81dbd7309..357517ed4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "control", srcs = ["control.go"], @@ -17,6 +17,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4e95101b7..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" @@ -194,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ // the available space, we must align down. // // align must be >= 4 and each data int32 is 4 bytes. The length of the - // header is already aligned, so if we align to the with of the data there + // header is already aligned, so if we align to the width of the data there // are two cases: // 1. The aligned length is less than the length of the header. The // unaligned length was also less than the length of the header, so we // can't write anything. // 2. The aligned length is greater than or equal to the length of the - // header. We can write the header plus zero or more datas. We can't write - // a partial int32, so the length of the message will be - // min(aligned length, header + datas). + // header. We can write the header plus zero or more bytes of data. We can't + // write a partial int32, so the length of the message will be + // min(aligned length, header + data). if space < linux.SizeOfControlMessageHeader { flags |= linux.MSG_CTRUNC return buf, flags @@ -239,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = binary.Marshal(buf, usermem.ByteOrder, data) - // Check if we went over. + // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { return hdrBuf } - // Fix up length. + // Update control message length to include data. putUint64(ob, uint64(len(buf)-len(ob))) return alignSlice(buf, align) @@ -320,35 +321,109 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { buf, linux.SOL_TCP, linux.TCP_INQ, - 4, + t.Arch().Width(), inq, ) } +// PackTOS packs an IP_TOS socket control message. +func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_TOS, + t.Arch().Width(), + tos, + ) +} + +// PackTClass packs an IPV6_TCLASS socket control message. +func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_TCLASS, + t.Arch().Width(), + tClass, + ) +} + +// PackControlMessages packs control messages into the given buffer. +// +// We skip control messages specific to Unix domain sockets. +// +// Note that some control messages may be truncated if they do not fit under +// the capacity of buf. +func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { + if cmsgs.IP.HasTimestamp { + buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) + } + + if cmsgs.IP.HasInq { + // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. + buf = PackInq(t, cmsgs.IP.Inq, buf) + } + + if cmsgs.IP.HasTOS { + buf = PackTOS(t, cmsgs.IP.TOS, buf) + } + + if cmsgs.IP.HasTClass { + buf = PackTClass(t, cmsgs.IP.TClass, buf) + } + + return buf +} + +// cmsgSpace is equivalent to CMSG_SPACE in Linux. +func cmsgSpace(t *kernel.Task, dataLen int) int { + return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width()) +} + +// CmsgsSpace returns the number of bytes needed to fit the control messages +// represented in cmsgs. +func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { + space := 0 + + if cmsgs.IP.HasTimestamp { + space += cmsgSpace(t, linux.SizeOfTimeval) + } + + if cmsgs.IP.HasInq { + space += cmsgSpace(t, linux.SizeOfControlMessageInq) + } + + if cmsgs.IP.HasTOS { + space += cmsgSpace(t, linux.SizeOfControlMessageTOS) + } + + if cmsgs.IP.HasTClass { + space += cmsgSpace(t, linux.SizeOfControlMessageTClass) + } + + return space +} + // Parse parses a raw socket control message into portable objects. -func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { +func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( - fds linux.ControlMessageRights - haveCreds bool - creds linux.ControlMessageCredentials + cmsgs socket.ControlMessages + fds linux.ControlMessageRights ) for i := 0; i < len(buf); { if i+linux.SizeOfControlMessageHeader > len(buf) { - return transport.ControlMessages{}, syserror.EINVAL + return cmsgs, syserror.EINVAL } var h linux.ControlMessageHeader binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) if h.Length < uint64(linux.SizeOfControlMessageHeader) { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } if h.Length > uint64(len(buf)-i) { - return transport.ControlMessages{}, syserror.EINVAL - } - if h.Level != linux.SOL_SOCKET { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } i += linux.SizeOfControlMessageHeader @@ -358,59 +433,79 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport. // sizeof(long) in CMSG_ALIGN. width := t.Arch().Width() - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) - numRights := rightsSize / linux.SizeOfControlMessageRight - - if len(fds)+numRights > linux.SCM_MAX_FD { - return transport.ControlMessages{}, syserror.EINVAL + switch h.Level { + case linux.SOL_SOCKET: + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) + numRights := rightsSize / linux.SizeOfControlMessageRight + + if len(fds)+numRights > linux.SCM_MAX_FD { + return socket.ControlMessages{}, syserror.EINVAL + } + + for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { + fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + } + + i += AlignUp(length, width) + + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + return socket.ControlMessages{}, syserror.EINVAL + } + + var creds linux.ControlMessageCredentials + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) + scmCreds, err := NewSCMCredentials(t, creds) + if err != nil { + return socket.ControlMessages{}, err + } + cmsgs.Unix.Credentials = scmCreds + i += AlignUp(length, width) + + default: + // Unknown message type. + return socket.ControlMessages{}, syserror.EINVAL } - - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + case linux.SOL_IP: + switch h.Type { + case linux.IP_TOS: + cmsgs.IP.HasTOS = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) + i += AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - i += AlignUp(length, width) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { - return transport.ControlMessages{}, syserror.EINVAL + case linux.SOL_IPV6: + switch h.Type { + case linux.IPV6_TCLASS: + cmsgs.IP.HasTClass = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) + i += AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) - haveCreds = true - i += AlignUp(length, width) - default: - // Unknown message type. - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } } - var credentials SCMCredentials - if haveCreds { - var err error - if credentials, err = NewSCMCredentials(t, creds); err != nil { - return transport.ControlMessages{}, err - } - } else { - credentials = makeCreds(t, socketOrEndpoint) + if cmsgs.Unix.Credentials == nil { + cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint) } - var rights SCMRights if len(fds) > 0 { - var err error - if rights, err = NewSCMRights(t, fds); err != nil { - return transport.ControlMessages{}, err + rights, err := NewSCMRights(t, fds) + if err != nil { + return socket.ControlMessages{}, err } + cmsgs.Unix.Rights = rights } - if credentials == nil && rights == nil { - return transport.ControlMessages{}, nil - } - - return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil + return cmsgs, nil } func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a951f1bb0..4c44c7c0f 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "hostinet", srcs = [ @@ -29,9 +29,12 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/control", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip/stack", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 92beb1bcf..c957b0f1d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -18,6 +18,7 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/fdnotifier" @@ -29,6 +30,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -41,6 +43,10 @@ const ( // sizeofSockaddr is the size in bytes of the largest sockaddr type // supported by this package. sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) + + // maxControlLen is the maximum size of a control message buffer used in a + // recvmsg or sendmsg syscall. + maxControlLen = 1024 ) // socketOperations implements fs.FileOperations and socket.Socket for a socket @@ -281,26 +287,32 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt // Whitelist options and constrain option length. var optlen int switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 - case syscall.SO_LINGER: + case linux.SO_LINGER: optlen = syscall.SizeofLinger } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 - case syscall.TCP_INFO: + case linux.TCP_INFO: optlen = int(linux.SizeOfTCPInfo) } } + if optlen == 0 { return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT } @@ -320,19 +332,24 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ // Whitelist options and constrain option length. var optlen int switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 } } @@ -354,11 +371,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that netstack/tcpip/transport/unix doesn't understand. Kill the + // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the // Socket interface's dependence on netstack. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument @@ -370,6 +387,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags senderAddrBuf = make([]byte, sizeofSockaddr) } + var controlBuf []byte var msgFlags int recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -384,11 +402,6 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We always do a non-blocking recv*(). sysflags := flags | syscall.MSG_DONTWAIT - if dsts.NumBlocks() == 1 { - // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) - } - iovs := iovecsFromBlockSeq(dsts) msg := syscall.Msghdr{ Iov: &iovs[0], @@ -398,12 +411,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Name = &senderAddrBuf[0] msg.Namelen = uint32(len(senderAddrBuf)) } + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) + controlLen = uint64(msg.Controllen) return n, nil }) @@ -429,14 +451,38 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err = dst.CopyOutFrom(t, recvmsgToBlocks) } } - - // We don't allow control messages. - msgFlags &^= linux.MSG_CTRUNC + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) + + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } + + controlMessages := socket.ControlMessages{} + for _, unixCmsg := range unixControlMessages { + switch unixCmsg.Header.Level { + case syscall.SOL_IP: + switch unixCmsg.Header.Type { + case syscall.IP_TOS: + controlMessages.IP.HasTOS = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + } + case syscall.SOL_IPV6: + switch unixCmsg.Header.Type { + case syscall.IPV6_TCLASS: + controlMessages.IP.HasTClass = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + } + } + } + + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil } // SendMsg implements socket.Socket.SendMsg. @@ -446,6 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.ErrInvalidArgument } + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. if uint64(src.NumBytes()) != srcs.NumBytes() { @@ -458,7 +512,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // We always do a non-blocking send*(). sysflags := flags | syscall.MSG_DONTWAIT - if srcs.NumBlocks() == 1 { + if srcs.NumBlocks() == 1 && len(controlBuf) == 0 { // Skip allocating []syscall.Iovec. src := srcs.Head() n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to))) @@ -477,6 +531,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] msg.Name = &to[0] msg.Namelen = uint32(len(to)) } + if len(controlBuf) != 0 { + msg.Control = &controlBuf[0] + msg.Controllen = uint64(len(controlBuf)) + } return sendmsg(s.fd, &msg, sysflags) }) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 3a4fdec47..e67b46c9e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -16,8 +16,11 @@ package hostinet import ( "fmt" + "io" "io/ioutil" "os" + "reflect" + "strconv" "strings" "syscall" @@ -26,7 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -51,6 +56,8 @@ type Stack struct { tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool + netDevFile *os.File + netSNMPFile *os.File } // NewStack returns an empty Stack containing no configuration. @@ -98,6 +105,18 @@ func (s *Stack) Configure() error { log.Warningf("Failed to read if TCP SACK if enabled, setting to true") } + if f, err := os.Open("/proc/net/dev"); err != nil { + log.Warningf("Failed to open /proc/net/dev: %v", err) + } else { + s.netDevFile = f + } + + if f, err := os.Open("/proc/net/snmp"); err != nil { + log.Warningf("Failed to open /proc/net/snmp: %v", err) + } else { + s.netSNMPFile = f + } + return nil } @@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// getLine reads one line from proc file, with specified prefix. +// The last argument, withHeader, specifies if it contains line header. +func getLine(f *os.File, prefix string, withHeader bool) string { + data := make([]byte, 4096) + + if _, err := f.Seek(0, 0); err != nil { + return "" + } + + if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF { + return "" + } + + prefix = prefix + ":" + lines := strings.Split(string(data), "\n") + for _, l := range lines { + l = strings.TrimSpace(l) + if strings.HasPrefix(l, prefix) { + if withHeader { + withHeader = false + continue + } + return l + } + } + return "" +} + +func toSlice(i interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(i)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserror.EOPNOTSUPP + var ( + snmpTCP bool + rawLine string + sliceStat []uint64 + ) + + switch stat.(type) { + case *inet.StatDev: + if s.netDevFile == nil { + return fmt.Errorf("/proc/net/dev is not opened for hostinet") + } + rawLine = getLine(s.netDevFile, arg, false /* with no header */) + case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite: + if s.netSNMPFile == nil { + return fmt.Errorf("/proc/net/snmp is not opened for hostinet") + } + rawLine = getLine(s.netSNMPFile, arg, true) + default: + return syserr.ErrEndpointOperation.ToError() + } + + if rawLine == "" { + return fmt.Errorf("Failed to get raw line") + } + + parts := strings.SplitN(rawLine, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("Failed to get prefix from: %q", rawLine) + } + + sliceStat = toSlice(stat) + fields := strings.Fields(strings.TrimSpace(parts[1])) + if len(fields) != len(sliceStat) { + return fmt.Errorf("Failed to parse fields: %q", rawLine) + } + if _, ok := stat.(*inet.StatSNMPTCP); ok { + snmpTCP = true + } + for i := 0; i < len(sliceStat); i++ { + var err error + if snmpTCP && i == 3 { + var tmp int64 + // MaxConn field is signed, RFC 2012. + tmp, err = strconv.ParseInt(fields[i], 10, 64) + sliceStat[i] = uint64(tmp) // Convert back to int before use. + } else { + sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) + } + if err != nil { + return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + } + } + + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 354a0d6ee..5eb06bbf4 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "netfilter", srcs = [ diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 45ebb2a0e..79589e3c8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "netlink", srcs = [ @@ -20,7 +20,9 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 9e2e12799..463544c1a 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "port", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 689cad997..be005df24 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -30,6 +30,13 @@ type Protocol interface { // Protocol returns the Linux netlink protocol value. Protocol() int + // CanSend returns true if this protocol may ever send messages. + // + // TODO(gvisor.dev/issue/1119): This is a workaround to allow + // advertising support for otherwise unimplemented features on sockets + // that will never send messages, thus making those features no-ops. + CanSend() bool + // ProcessMessage processes a single message from userspace. // // If err == nil, any messages added to ms will be sent back to the diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 5dc8533ec..1d4912753 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "route", srcs = ["protocol.go"], diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index cc70ac237..6b4a0ecf4 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -61,6 +61,11 @@ func (p *Protocol) Protocol() int { return linux.NETLINK_ROUTE } +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return true +} + // dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests. func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d0aab293d..4a1b87a9a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -27,7 +27,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -52,6 +54,8 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) +var errNoFilter = syserr.New("no filter attached", linux.ENOENT) + // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() @@ -60,7 +64,7 @@ var netlinkSocketDevice = device.NewAnonDevice() // This implementation only supports userspace sending and receiving messages // to/from the kernel. // -// Socket implements socket.Socket. +// Socket implements socket.Socket and transport.Credentialer. // // +stateify savable type Socket struct { @@ -103,9 +107,19 @@ type Socket struct { // sendBufferSize is the send buffer "size". We don't actually have a // fixed buffer but only consume this many bytes. sendBufferSize uint32 + + // passcred indicates if this socket wants SCM credentials. + passcred bool + + // filter indicates that this socket has a BPF filter "installed". + // + // TODO(gvisor.dev/issue/1119): We don't actually support filtering, + // this is just bookkeeping for tracking add/remove. + filter bool } var _ socket.Socket = (*Socket)(nil) +var _ transport.Credentialer = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -171,6 +185,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } +// Passcred implements transport.Credentialer.Passcred. +func (s *Socket) Passcred() bool { + s.mu.Lock() + passcred := s.passcred + s.mu.Unlock() + return passcred +} + +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. +func (s *Socket) ConnectedPasscred() bool { + // This socket is connected to the kernel, which doesn't need creds. + // + // This is arbitrary, as ConnectedPasscred on this type has no callers. + return false +} + // Ioctl implements fs.FileOperations.Ioctl. func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) { // TODO(b/68878065): no ioctls supported. @@ -308,9 +338,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. // We don't have limit on receiving size. return int32(math.MaxInt32), nil + case linux.SO_PASSCRED: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + var passcred int32 + if s.Passcred() { + passcred = 1 + } + return passcred, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } + case linux.SOL_NETLINK: switch name { case linux.NETLINK_BROADCAST_ERROR, @@ -347,6 +388,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy s.sendBufferSize = size s.mu.Unlock() return nil + case linux.SO_RCVBUF: if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -354,6 +396,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy // We don't have limit on receiving size. So just accept anything as // valid for compatibility. return nil + + case linux.SO_PASSCRED: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + passcred := usermem.ByteOrder.Uint32(opt) + + s.mu.Lock() + s.passcred = passcred != 0 + s.mu.Unlock() + return nil + + case linux.SO_ATTACH_FILTER: + // TODO(gvisor.dev/issue/1119): We don't actually + // support filtering. If this socket can't ever send + // messages, then there is nothing to filter and we can + // advertise support. Otherwise, be conservative and + // return an error. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + s.filter = true + s.mu.Unlock() + return nil + + case linux.SO_DETACH_FILTER: + // TODO(gvisor.dev/issue/1119): See above. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + filter := s.filter + s.filter = false + s.mu.Unlock() + + if !filter { + return errNoFilter + } + + return nil + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -416,6 +504,24 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have Peek: flags&linux.MSG_PEEK != 0, } + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. However, the memory manager for the destination will not read + // the endpoint if the destination is zero length. + // + // In order for the endpoint to be read when the destination size is zero, + // we must cause a read of the endpoint by using a separate fake zero + // length block sequence and calling the EndpointReader directly. + if trunc && dst.Addrs.NumBytes() == 0 { + // Perform a read to a zero byte block sequence. We can ignore the + // original destination since it was zero bytes. The length returned by + // ReadToBlocks is ignored and we return the full message length to comply + // with MSG_TRUNC. + _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) + return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + } + if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { @@ -464,6 +570,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ }) } +// kernelSCM implements control.SCMCredentials with credentials that represent +// the kernel itself rather than a Task. +// +// +stateify savable +type kernelSCM struct{} + +// Equals implements transport.CredentialsControlMessage.Equals. +func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool { + _, ok := oc.(kernelSCM) + return ok +} + +// Credentials implements control.SCMCredentials.Credentials. +func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { + return 0, auth.RootUID, auth.RootGID +} + +// kernelCreds is the concrete version of kernelSCM used in all creds. +var kernelCreds = &kernelSCM{} + // sendResponse sends the response messages in ms back to userspace. func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. @@ -472,10 +598,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error bufs = append(bufs, m.Finalize()) } + // All messages are from the kernel. + cms := transport.ControlMessages{ + Credentials: kernelCreds, + } + if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -499,7 +630,10 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error PortID: uint32(ms.PortID), }) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) + // Add the dump_done_errno payload. + m.Put(int64(0)) + + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD new file mode 100644 index 000000000..0777f3baf --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "uevent", + srcs = ["protocol.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netlink", + "//pkg/syserr", + ], +) diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go new file mode 100644 index 000000000..b5d7808d7 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -0,0 +1,60 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol. +// +// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does +// not support any device events, so these sockets never send any messages. +package uevent + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/syserr" +) + +// Protocol implements netlink.Protocol. +// +// +stateify savable +type Protocol struct{} + +var _ netlink.Protocol = (*Protocol)(nil) + +// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol. +func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) { + return &Protocol{}, nil +} + +// Protocol implements netlink.Protocol.Protocol. +func (p *Protocol) Protocol() int { + return linux.NETLINK_KOBJECT_UEVENT +} + +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return false +} + +// ProcessMessage implements netlink.Protocol.ProcessMessage. +func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { + // Silently ignore all messages. + return nil +} + +// init registers the NETLINK_KOBJECT_UEVENT provider. +func init() { + netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol) +} diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/netstack/BUILD index e927821e1..e414d8055 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,17 +1,17 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( - name = "epsocket", + name = "netstack", srcs = [ "device.go", - "epsocket.go", + "netstack.go", "provider.go", "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket", + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/netstack/device.go index 85484d5b1..fbeb89fb8 100644 --- a/pkg/sentry/socket/epsocket/device.go +++ b/pkg/sentry/socket/netstack/device.go @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import "gvisor.dev/gvisor/pkg/sentry/device" -// epsocketDevice is the endpoint socket virtual device. -var epsocketDevice = device.NewAnonDevice() +// netstackDevice is the endpoint socket virtual device. +var netstackDevice = device.NewAnonDevice() diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/netstack/netstack.go index 0e37ce61b..140851c17 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package epsocket provides an implementation of the socket.Socket interface +// Package netstack provides an implementation of the socket.Socket interface // that is backed by a tcpip.Endpoint. // // It does not depend on any particular endpoint implementation, and thus can @@ -22,10 +22,11 @@ // Lock ordering: netstack => mm: ioSequencePayload copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. -package epsocket +package netstack import ( "bytes" + "io" "math" "reflect" "sync" @@ -52,6 +53,7 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -136,15 +138,21 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), + CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -154,6 +162,7 @@ var Metrics = tcpip.Stats{ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."), InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."), SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."), + SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."), ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."), ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."), Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."), @@ -169,13 +178,18 @@ var Metrics = tcpip.Stats{ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."), ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."), MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), - PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."), + PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), + PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), }, } +// DefaultTTL is linux's default TTL. All network protocols in all stacks used +// with this package must have this value set as their default TTL. +const DefaultTTL = 64 + const sizeOfInt32 int = 4 -var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL) +var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) // ntohs converts a 16-bit number from network byte order to host byte order. It // assumes that the host is little endian. @@ -208,6 +222,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -227,7 +245,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -258,20 +275,20 @@ type SocketOperations struct { // valid when timestampValid is true. It is protected by readMu. timestampNS int64 - // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket - // level, because it takes into account data from readView. + // sockOptInq corresponds to TCP_INQ. It is implemented at this level + // because it takes into account data from readView. sockOptInq bool } // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { + if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { return nil, syserr.TranslateNetstackError(err) } } - dirent := socket.NewDirent(t, epsocketDevice) + dirent := socket.NewDirent(t, netstackDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ Queue: queue, @@ -284,6 +301,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) +var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -295,12 +313,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address { } // AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and -// AF_INET6 addresses. +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. // // strict indicates whether addresses with the AF_UNSPEC family are accepted of not. // -// AddressAndFamily returns an address, its family. +// AddressAndFamily returns an address and its family. func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { @@ -308,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { + if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported } @@ -359,6 +377,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } return out, family, nil + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(b/129292371): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + case linux.AF_UNSPEC: return tcpip.FullAddress{}, family, nil @@ -412,17 +446,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + break + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + s.readMu.Unlock() + return done, err + } + } + + s.readMu.Unlock() + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -431,11 +508,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -469,6 +541,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + Atomic: true, // See above. + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -646,7 +790,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // tcpip.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -719,7 +863,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return getSockOptIPv6(t, ep, name, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen) + return getSockOptIP(t, ep, name, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -777,8 +921,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -793,8 +937,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -828,6 +972,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -900,8 +1057,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - var v tcpip.DelayOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -970,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Millisecond), nil + case linux.TCP_INFO: var v tcpip.TCPInfoOption if err := ep.GetSockOpt(&v); err != nil { @@ -1018,6 +1187,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa copy(b, v) return b, nil + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1042,6 +1223,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + // Length handling for parity with Linux. + if outLen == 0 { + return make([]byte, 0), nil + } + var v tcpip.IPv6TrafficClassOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + uintv := uint32(v) + // Linux truncates the output binary to outLen. + ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + // Handle cases where outLen is lesser than sizeOfInt32. + if len(ib) > outLen { + ib = ib[:outLen] + } + return ib, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1049,8 +1249,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { switch name { + case linux.IP_TTL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TTLOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // Fill in the default value, if needed. + if v == 0 { + v = DefaultTTL + } + + return int32(v), nil + case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1092,6 +1309,20 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac } return int32(0), nil + case linux.IP_TOS: + // Length handling for parity with Linux. + if outLen == 0 { + return []byte(nil), nil + } + var v tcpip.IPv4TOSOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if outLen < sizeOfInt32 { + return uint8(v), nil + } + return int32(v), nil + default: emitUnimplementedEventIP(t, name) } @@ -1102,7 +1333,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac // tcpip.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -1165,7 +1396,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1173,7 +1404,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -1191,6 +1422,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -1285,11 +1523,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o tcpip.DelayOption + var o int if v == 0 { o = 1 } - return syserr.TranslateNetstackError(ep.SetSockOpt(o)) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1337,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { @@ -1344,6 +1593,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1383,6 +1640,19 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < -1 || v > 255 { + return syserr.ErrInvalidArgument + } + if v == -1 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + default: emitUnimplementedEventIPv6(t, name) } @@ -1514,6 +1784,30 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) return syserr.ErrInvalidArgument + case linux.IP_TTL: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + // -1 means default TTL. + if v == -1 { + v = 0 + } else if v < 1 || v > 255 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + + case linux.IP_TOS: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1537,9 +1831,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, - linux.IP_TOS, linux.IP_TRANSPARENT, - linux.IP_TTL, linux.IP_UNBLOCK_SOURCE, linux.IP_UNICAST_IF, linux.IP_XFRM_POLICY, @@ -1724,12 +2016,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(2 + l) } return &out, uint32(3 + l) + case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return &out, uint32(binary.Size(out)) + return &out, uint32(sockAddrInetSize) + case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1745,7 +2039,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return &out, uint32(binary.Size(out)) + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(b/129292371): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + default: return nil, 0 } @@ -2060,7 +2364,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2085,7 +2389,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2101,7 +2405,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { @@ -2207,9 +2511,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { @@ -2404,7 +2708,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Flag values and meanings are described in greater detail in netdevice(7) in // the SIOCGIFFLAGS section. func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) { - // epsocket should only ever be passed an epsocket.Stack. + // We should only ever be passed a netstack.Stack. epstack, ok := stack.(*Stack) if !ok { return 0, errStackType diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/netstack/provider.go index 421f93dc4..2d2c1ba2a 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "syscall" @@ -62,10 +62,14 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in } case linux.SOCK_RAW: + // TODO(b/142504697): "In order to create a raw socket, a + // process must have the CAP_NET_RAW capability in the user + // namespace that governs its network namespace." - raw(7) + // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, true, syserr.ErrPermissionDenied + return 0, true, syserr.ErrNotPermitted } switch protocol { @@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in return 0, true, syserr.ErrProtocolNotSupported } -// Socket creates a new socket object for the AF_INET or AF_INET6 family. +// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET +// family. func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() @@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, nil } + // Packet sockets are handled separately, since they are neither INET + // nor INET6 specific. + if p.family == linux.AF_PACKET { + return packetSocket(t, eps, stype, protocol) + } + // Figure out the transport protocol. transProto, associated, err := getTransportProtocol(t, stype, protocol) if err != nil { @@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return New(t, p.family, stype, int(transProto), wq, ep) } +func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { + // TODO(b/142504697): "In order to create a packet socket, a process + // must have the CAP_NET_RAW capability in the user namespace that + // governs its network namespace." - packet(7) + + // Packet sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(t) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return nil, syserr.ErrNotPermitted + } + + // "cooked" packets don't contain link layer information. + var cooked bool + switch stype { + case linux.SOCK_DGRAM: + cooked = true + case linux.SOCK_RAW: + cooked = false + default: + return nil, syserr.ErrProtocolNotSupported + } + + // protocol is passed in network byte order, but netstack wants it in + // host order. + netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + + wq := &waiter.Queue{} + ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return New(t, linux.AF_PACKET, stype, protocol, wq, ep) +} + // Pair just returns nil sockets (not supported). func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } -// init registers socket providers for AF_INET and AF_INET6. +// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET. func init() { // Providers backed by netstack. p := []provider{ @@ -138,6 +184,9 @@ func init() { family: linux.AF_INET6, netProto: ipv6.ProtocolNumber, }, + { + family: linux.AF_PACKET, + }, } for i := range p { diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go index f7b8c10cc..c7aaf722a 100644 --- a/pkg/sentry/socket/epsocket/save_restore.go +++ b/pkg/sentry/socket/netstack/save_restore.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/netstack/stack.go index 7cf7ff735..a0db2d4fd 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" @@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() + switch stats := stat.(type) { + case *inet.StatSNMPIP: + ip := Metrics.IP + *stats = inet.StatSNMPIP{ + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + } + case *inet.StatSNMPICMP: + in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + *stats = inet.StatSNMPICMP{ + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + in.DstUnreachable.Value(), // InDestUnreachs. + in.TimeExceeded.Value(), // InTimeExcds. + in.ParamProblem.Value(), // InParmProbs. + in.SrcQuench.Value(), // InSrcQuenchs. + in.Redirect.Value(), // InRedirects. + in.Echo.Value(), // InEchos. + in.EchoReply.Value(), // InEchoReps. + in.Timestamp.Value(), // InTimestamps. + in.TimestampReply.Value(), // InTimestampReps. + in.InfoRequest.Value(), // InAddrMasks. + in.InfoReply.Value(), // InAddrMaskReps. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. + } + case *inet.StatSNMPTCP: + tcp := Metrics.TCP + // RFC 2012 (updates 1213): SNMPv2-MIB-TCP. + *stats = inet.StatSNMPTCP{ + 1, // RtoAlgorithm. + 200, // RtoMin. + 120000, // RtoMax. + (1<<64 - 1), // MaxConn. + tcp.ActiveConnectionOpenings.Value(), // ActiveOpens. + tcp.PassiveConnectionOpenings.Value(), // PassiveOpens. + tcp.FailedConnectionAttempts.Value(), // AttemptFails. + tcp.EstablishedResets.Value(), // EstabResets. + tcp.CurrentEstablished.Value(), // CurrEstab. + tcp.ValidSegmentsReceived.Value(), // InSegs. + tcp.SegmentsSent.Value(), // OutSegs. + tcp.Retransmits.Value(), // RetransSegs. + tcp.InvalidSegmentsReceived.Value(), // InErrs. + tcp.ResetsSent.Value(), // OutRsts. + tcp.ChecksumErrors.Value(), // InCsumErrors. + } + case *inet.StatSNMPUDP: + udp := Metrics.UDP + *stats = inet.StatSNMPUDP{ + udp.PacketsReceived.Value(), // InDatagrams. + udp.UnknownPortErrors.Value(), // NoPorts. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + udp.PacketsSent.Value(), // OutDatagrams. + udp.ReceiveBufferErrors.Value(), // RcvbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + } + default: + return syserr.ErrEndpointOperation.ToError() + } + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 5061dcbde..4668b87d1 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -36,6 +37,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", "//pkg/unet", "//pkg/waiter", ], @@ -49,6 +51,14 @@ proto_library( ], ) +cc_proto_library( + name = "syscall_rpc_cc_proto", + visibility = [ + "//visibility:public", + ], + deps = [":syscall_rpc_proto"], +) + go_proto_library( name = "syscall_rpc_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f7878a760 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto index 9586f5923..b677e9eb3 100644 --- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto +++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto @@ -3,7 +3,6 @@ syntax = "proto3"; // package syscall_rpc is a set of networking related system calls that can be // forwarded to a socket gofer. // -// TODO(b/77963526): Document individual RPCs. package syscall_rpc; message SendmsgRequest { diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 8c250c325..2389a9cdb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -43,6 +43,11 @@ type ControlMessages struct { IP tcpip.ControlMessages } +// Release releases Unix domain socket credentials and rights. +func (c *ControlMessages) Release() { + c.Unix.Release() +} + // Socket is the interface containing socket syscalls used by the syscall layer // to redirect them to the appropriate implementation. type Socket interface { diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index da9977fde..5b6a154f6 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "unix", srcs = [ @@ -24,7 +24,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", + "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 0b0240336..788ad70d2 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") +package(licenses = ["notice"]) + go_template_instance( name = "transport_message_list", out = "transport_message_list.go", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 4bd15808a..dea11e253 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -220,6 +220,11 @@ func (e *connectionedEndpoint) Close() { case e.Connected(): e.connected.CloseSend() e.receiver.CloseRecv() + // Still have unread data? If yes, we set this into the write + // end so that the peer can get ECONNRESET) when it does read. + if e.receiver.RecvQueuedSize() > 0 { + e.connected.CloseUnread() + } c = e.connected r = e.receiver e.connected = nil diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 0415fae9a..e27b1c714 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -33,6 +33,7 @@ type queue struct { mu sync.Mutex `state:"nosave"` closed bool + unread bool used int64 limit int64 dataList messageList @@ -161,6 +162,9 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) { err := syserr.ErrWouldBlock if q.closed { err = syserr.ErrClosedForReceive + if q.unread { + err = syserr.ErrConnectionReset + } } q.mu.Unlock() @@ -188,7 +192,9 @@ func (q *queue) Peek() (*message, *syserr.Error) { if q.dataList.Front() == nil { err := syserr.ErrWouldBlock if q.closed { - err = syserr.ErrClosedForReceive + if err = syserr.ErrClosedForReceive; q.unread { + err = syserr.ErrConnectionReset + } } return nil, err } @@ -208,3 +214,11 @@ func (q *queue) QueuedSize() int64 { func (q *queue) MaxQueueSize() int64 { return q.limit } + +// CloseUnread sets flag to indicate that the peer is closed (not shutdown) +// with unread data. So if read on this queue shall return ECONNRESET error. +func (q *queue) CloseUnread() { + q.mu.Lock() + defer q.mu.Unlock() + q.unread = true +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2b0ad6395..529a7a7a9 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -604,6 +608,10 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. Release() + + // CloseUnread sets the fact that this end is closed with unread data to + // the peer socket. + CloseUnread() } // +stateify savable @@ -707,6 +715,11 @@ func (e *connectedEndpoint) Release() { e.writeQueue.DecRef() } +// CloseUnread implements ConnectedEndpoint.CloseUnread. +func (e *connectedEndpoint) CloseUnread() { + e.writeQueue.CloseUnread() +} + // baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless // unix domain socket Endpoint implementations. // @@ -838,6 +851,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -853,65 +870,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 0d0cb68df..885758054 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -31,7 +31,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -40,8 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SocketOperations is a Unix socket. It is similar to an epsocket, except it -// is backed by a transport.Endpoint instead of a tcpip.Endpoint. +// SocketOperations is a Unix socket. It is similar to a netstack socket, +// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint. // // +stateify savable type SocketOperations struct { @@ -116,8 +116,11 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } @@ -143,7 +146,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -155,19 +158,19 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - return epsocket.Ioctl(ctx, s.ep, io, args) + return netstack.Ioctl(ctx, s.ep, io, args) } // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -474,13 +477,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) { // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { - return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal) + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { - f, err := epsocket.ConvertShutdown(how) + f, err := netstack.ConvertShutdown(how) if err != nil { return err } @@ -546,7 +549,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -581,7 +584,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -595,7 +598,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { + streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed { if total > 0 { err = nil } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 445d25010..d46421199 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -13,6 +14,7 @@ go_library( "open.go", "poll.go", "ptrace.go", + "select.go", "signal.go", "socket.go", "strace.go", @@ -31,8 +33,8 @@ go_library( "//pkg/sentry/arch", "//pkg/sentry/kernel", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", "//pkg/sentry/socket/netlink", + "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", ], @@ -44,6 +46,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "strace_cc_proto", + visibility = ["//visibility:public"], + deps = [":strace_proto"], +) + go_proto_library( name = "strace_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 3650fd6e1..e603f858f 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -40,7 +40,7 @@ var linuxAMD64 = SyscallMap{ 20: makeSyscallInfo("writev", FD, WriteIOVec, Hex), 21: makeSyscallInfo("access", Path, Oct), 22: makeSyscallInfo("pipe", PipeFDs), - 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval), + 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval), 24: makeSyscallInfo("sched_yield"), 25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex), 26: makeSyscallInfo("msync", Hex, Hex, Hex), @@ -287,7 +287,7 @@ var linuxAMD64 = SyscallMap{ 267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex), 268: makeSyscallInfo("fchmodat", FD, Path, Mode), 269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex), - 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex), + 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet), 271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex), 272: makeSyscallInfo("unshare", CloneFlags), 273: makeSyscallInfo("set_robust_list", Hex, Hex), @@ -335,4 +335,33 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 318: makeSyscallInfo("getrandom", Hex, Hex, Hex), + 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close. + 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex), + 321: makeSyscallInfo("bpf", Hex, Hex, Hex), + 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex), + 323: makeSyscallInfo("userfaultfd", Hex), + 324: makeSyscallInfo("membarrier", Hex, Hex), + 325: makeSyscallInfo("mlock2", Hex, Hex, Hex), + 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex), + 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex), + 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex), + 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex), + 330: makeSyscallInfo("pkey_alloc", Hex, Hex), + 331: makeSyscallInfo("pkey_free", Hex), + 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), + 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet), + 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex), + 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex), + 425: makeSyscallInfo("io_uring_setup", Hex, Hex), + 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex), + 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex), + 428: makeSyscallInfo("open_tree", FD, Path, Hex), + 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex), + 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close. + 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex), + 432: makeSyscallInfo("fsmount", FD, Hex, Hex), + 433: makeSyscallInfo("fspick", FD, Path, Hex), + 434: makeSyscallInfo("pidfd_open", Hex, Hex), + 435: makeSyscallInfo("clone3", Hex, Hex), } diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go new file mode 100644 index 000000000..dea309fda --- /dev/null +++ b/pkg/sentry/strace/select.go @@ -0,0 +1,53 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func fdsFromSet(t *kernel.Task, set []byte) []int { + var fds []int + // Append n if the n-th bit is 1. + for i, v := range set { + for j := 0; j < 8; j++ { + if (v>>j)&1 == 1 { + fds = append(fds, i*8+j) + } + } + } + return fds +} + +func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string { + if addr == 0 { + return "null" + } + + // Calculate the size of the fd set (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 + + set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte) + if err != nil { + return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err) + } + + return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set)) +} diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f779186ad..51f2efb39 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/usermem" ) @@ -208,6 +208,15 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) i += linux.SizeOfControlMessageHeader width := t.Arch().Width() length := int(h.Length) - linux.SizeOfControlMessageHeader + if length < 0 { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) @@ -332,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */) + fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 311389547..629c1f308 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -439,6 +439,8 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer())) case PollFDs: output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false)) + case SelectFDSet: + output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) case Oct: output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8)) case Hex: diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto index 4b2f73a5f..906c52c51 100644 --- a/pkg/sentry/strace/strace.proto +++ b/pkg/sentry/strace/strace.proto @@ -32,8 +32,7 @@ message Strace { } } -message StraceEnter { -} +message StraceEnter {} message StraceExit { // Return value formatted as string. diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 3c389d375..e5d486c4e 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -206,6 +206,10 @@ const ( // PollFDs is an array of struct pollfd. The number of entries in the // array is in the next argument. PollFDs + + // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of + // fds represented must be the first argument. + SelectFDSet ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 33a40b9c6..6766ba587 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -1,13 +1,15 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "linux", srcs = [ "error.go", "flags.go", "linux64.go", + "linux64_amd64.go", + "linux64_arm64.go", "sigset.go", "sys_aio.go", "sys_capability.go", @@ -47,6 +49,7 @@ go_library( "sys_tls.go", "sys_utsname.go", "sys_write.go", + "sys_xattr.go", "timespec.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", @@ -74,8 +77,10 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", + "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", + "//pkg/sentry/loader", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/safemem", diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go index 444f2b004..07961dad9 100644 --- a/pkg/sentry/syscalls/linux/flags.go +++ b/pkg/sentry/syscalls/linux/flags.go @@ -50,5 +50,6 @@ func linuxToFlags(mask uint) fs.FileFlags { Directory: mask&linux.O_DIRECTORY != 0, Async: mask&linux.O_ASYNC != 0, LargeFile: mask&linux.O_LARGEFILE != 0, + Truncate: mask&linux.O_TRUNC != 0, } } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ed996ba51..68589a377 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,377 +15,13 @@ // Package linux provides syscall tables for amd64 Linux. package linux -import ( - "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// AUDIT_ARCH_X86_64 identifies the Linux syscall API on AMD64, and is taken -// from <linux/audit.h>. -const _AUDIT_ARCH_X86_64 = 0xc000003e +const ( + // LinuxSysname is the OS name advertised by gVisor. + LinuxSysname = "Linux" -// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall -// numbers from Linux 4.4. -var AMD64 = &kernel.SyscallTable{ - OS: abi.Linux, - Arch: arch.AMD64, - Version: kernel.Version{ - // Version 4.4 is chosen as a stable, longterm version of Linux, which - // guides the interface provided by this syscall table. The build - // version is that for a clean build with default kernel config, at 5 - // minutes after v4.4 was tagged. - Sysname: "Linux", - Release: "4.4", - Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016", - }, - AuditNumber: _AUDIT_ARCH_X86_64, - Table: map[uintptr]kernel.Syscall{ - 0: syscalls.Supported("read", Read), - 1: syscalls.Supported("write", Write), - 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil), - 3: syscalls.Supported("close", Close), - 4: syscalls.Supported("stat", Stat), - 5: syscalls.Supported("fstat", Fstat), - 6: syscalls.Supported("lstat", Lstat), - 7: syscalls.Supported("poll", Poll), - 8: syscalls.Supported("lseek", Lseek), - 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil), - 10: syscalls.Supported("mprotect", Mprotect), - 11: syscalls.Supported("munmap", Munmap), - 12: syscalls.Supported("brk", Brk), - 13: syscalls.Supported("rt_sigaction", RtSigaction), - 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask), - 15: syscalls.Supported("rt_sigreturn", RtSigreturn), - 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), - 17: syscalls.Supported("pread64", Pread64), - 18: syscalls.Supported("pwrite64", Pwrite64), - 19: syscalls.Supported("readv", Readv), - 20: syscalls.Supported("writev", Writev), - 21: syscalls.Supported("access", Access), - 22: syscalls.Supported("pipe", Pipe), - 23: syscalls.Supported("select", Select), - 24: syscalls.Supported("sched_yield", SchedYield), - 25: syscalls.Supported("mremap", Mremap), - 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), - 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), - 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), - 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), - 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), - 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), - 32: syscalls.Supported("dup", Dup), - 33: syscalls.Supported("dup2", Dup2), - 34: syscalls.Supported("pause", Pause), - 35: syscalls.Supported("nanosleep", Nanosleep), - 36: syscalls.Supported("getitimer", Getitimer), - 37: syscalls.Supported("alarm", Alarm), - 38: syscalls.Supported("setitimer", Setitimer), - 39: syscalls.Supported("getpid", Getpid), - 40: syscalls.Supported("sendfile", Sendfile), - 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), - 42: syscalls.Supported("connect", Connect), - 43: syscalls.Supported("accept", Accept), - 44: syscalls.Supported("sendto", SendTo), - 45: syscalls.Supported("recvfrom", RecvFrom), - 46: syscalls.Supported("sendmsg", SendMsg), - 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), - 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), - 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), - 50: syscalls.Supported("listen", Listen), - 51: syscalls.Supported("getsockname", GetSockName), - 52: syscalls.Supported("getpeername", GetPeerName), - 53: syscalls.Supported("socketpair", SocketPair), - 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), - 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), - 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), - 57: syscalls.Supported("fork", Fork), - 58: syscalls.Supported("vfork", Vfork), - 59: syscalls.Supported("execve", Execve), - 60: syscalls.Supported("exit", Exit), - 61: syscalls.Supported("wait4", Wait4), - 62: syscalls.Supported("kill", Kill), - 63: syscalls.Supported("uname", Uname), - 64: syscalls.Supported("semget", Semget), - 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), - 67: syscalls.Supported("shmdt", Shmdt), - 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), - 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), - 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), - 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), - 76: syscalls.Supported("truncate", Truncate), - 77: syscalls.Supported("ftruncate", Ftruncate), - 78: syscalls.Supported("getdents", Getdents), - 79: syscalls.Supported("getcwd", Getcwd), - 80: syscalls.Supported("chdir", Chdir), - 81: syscalls.Supported("fchdir", Fchdir), - 82: syscalls.Supported("rename", Rename), - 83: syscalls.Supported("mkdir", Mkdir), - 84: syscalls.Supported("rmdir", Rmdir), - 85: syscalls.Supported("creat", Creat), - 86: syscalls.Supported("link", Link), - 87: syscalls.Supported("unlink", Unlink), - 88: syscalls.Supported("symlink", Symlink), - 89: syscalls.Supported("readlink", Readlink), - 90: syscalls.Supported("chmod", Chmod), - 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), - 92: syscalls.Supported("chown", Chown), - 93: syscalls.Supported("fchown", Fchown), - 94: syscalls.Supported("lchown", Lchown), - 95: syscalls.Supported("umask", Umask), - 96: syscalls.Supported("gettimeofday", Gettimeofday), - 97: syscalls.Supported("getrlimit", Getrlimit), - 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), - 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), - 100: syscalls.Supported("times", Times), - 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), - 102: syscalls.Supported("getuid", Getuid), - 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), - 104: syscalls.Supported("getgid", Getgid), - 105: syscalls.Supported("setuid", Setuid), - 106: syscalls.Supported("setgid", Setgid), - 107: syscalls.Supported("geteuid", Geteuid), - 108: syscalls.Supported("getegid", Getegid), - 109: syscalls.Supported("setpgid", Setpgid), - 110: syscalls.Supported("getppid", Getppid), - 111: syscalls.Supported("getpgrp", Getpgrp), - 112: syscalls.Supported("setsid", Setsid), - 113: syscalls.Supported("setreuid", Setreuid), - 114: syscalls.Supported("setregid", Setregid), - 115: syscalls.Supported("getgroups", Getgroups), - 116: syscalls.Supported("setgroups", Setgroups), - 117: syscalls.Supported("setresuid", Setresuid), - 118: syscalls.Supported("getresuid", Getresuid), - 119: syscalls.Supported("setresgid", Setresgid), - 120: syscalls.Supported("getresgid", Getresgid), - 121: syscalls.Supported("getpgid", Getpgid), - 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 124: syscalls.Supported("getsid", Getsid), - 125: syscalls.Supported("capget", Capget), - 126: syscalls.Supported("capset", Capset), - 127: syscalls.Supported("rt_sigpending", RtSigpending), - 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), - 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), - 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend), - 131: syscalls.Supported("sigaltstack", Sigaltstack), - 132: syscalls.Supported("utime", Utime), - 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), - 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), - 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), - 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), - 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), - 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), - 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), - 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), - 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), - 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), - 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), - 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), - 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), - 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), - 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), - 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), - 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), - 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil), - 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil), - 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil), - 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), - 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil), - 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), - 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), - 161: syscalls.Supported("chroot", Chroot), - 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), - 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), - 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), - 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), - 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), - 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), - 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), - 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), - 170: syscalls.Supported("sethostname", Sethostname), - 171: syscalls.Supported("setdomainname", Setdomainname), - 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil), - 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil), - 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), - 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), - 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), - 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations - 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), - 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), - 186: syscalls.Supported("gettid", Gettid), - 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) - 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 200: syscalls.Supported("tkill", Tkill), - 201: syscalls.Supported("time", Time), - 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), - 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), - 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), - 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), - 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), - 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), - 213: syscalls.Supported("epoll_create", EpollCreate), - 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), - 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), - 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), - 217: syscalls.Supported("getdents64", Getdents64), - 218: syscalls.Supported("set_tid_address", SetTidAddress), - 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) - 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), - 222: syscalls.Supported("timer_create", TimerCreate), - 223: syscalls.Supported("timer_settime", TimerSettime), - 224: syscalls.Supported("timer_gettime", TimerGettime), - 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun), - 226: syscalls.Supported("timer_delete", TimerDelete), - 227: syscalls.Supported("clock_settime", ClockSettime), - 228: syscalls.Supported("clock_gettime", ClockGettime), - 229: syscalls.Supported("clock_getres", ClockGetres), - 230: syscalls.Supported("clock_nanosleep", ClockNanosleep), - 231: syscalls.Supported("exit_group", ExitGroup), - 232: syscalls.Supported("epoll_wait", EpollWait), - 233: syscalls.Supported("epoll_ctl", EpollCtl), - 234: syscalls.Supported("tgkill", Tgkill), - 235: syscalls.Supported("utimes", Utimes), - 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), - 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), - 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), - 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), - 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), - 247: syscalls.Supported("waitid", Waitid), - 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), - 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), - 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), - 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), - 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), - 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), - 257: syscalls.Supported("openat", Openat), - 258: syscalls.Supported("mkdirat", Mkdirat), - 259: syscalls.Supported("mknodat", Mknodat), - 260: syscalls.Supported("fchownat", Fchownat), - 261: syscalls.Supported("futimesat", Futimesat), - 262: syscalls.Supported("fstatat", Fstatat), - 263: syscalls.Supported("unlinkat", Unlinkat), - 264: syscalls.Supported("renameat", Renameat), - 265: syscalls.Supported("linkat", Linkat), - 266: syscalls.Supported("symlinkat", Symlinkat), - 267: syscalls.Supported("readlinkat", Readlinkat), - 268: syscalls.Supported("fchmodat", Fchmodat), - 269: syscalls.Supported("faccessat", Faccessat), - 270: syscalls.Supported("pselect", Pselect), - 271: syscalls.Supported("ppoll", Ppoll), - 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), - 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) - 280: syscalls.Supported("utimensat", Utimensat), - 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) - 283: syscalls.Supported("timerfd_create", TimerfdCreate), - 284: syscalls.Supported("eventfd", Eventfd), - 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), - 286: syscalls.Supported("timerfd_settime", TimerfdSettime), - 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), - 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) - 290: syscalls.Supported("eventfd2", Eventfd2), - 291: syscalls.Supported("epoll_create1", EpollCreate1), - 292: syscalls.Supported("dup3", Dup3), - 293: syscalls.Supported("pipe2", Pipe2), - 294: syscalls.Supported("inotify_init1", InotifyInit1), - 295: syscalls.Supported("preadv", Preadv), - 296: syscalls.Supported("pwritev", Pwritev), - 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), - 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), - 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), - 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 302: syscalls.Supported("prlimit64", Prlimit64), - 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), - 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), - 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), - 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), - 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), - 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) - 309: syscalls.Supported("getcpu", Getcpu), - 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), - 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), - 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) - 317: syscalls.Supported("seccomp", Seccomp), - 318: syscalls.Supported("getrandom", GetRandom), - 319: syscalls.Supported("memfd_create", MemfdCreate), - 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), - 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) - 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) - 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + // LinuxRelease is the Linux release version number advertised by gVisor. + LinuxRelease = "4.4.0" - // Syscalls after 325 are "backports" from versions of Linux after 4.4. - 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), - 327: syscalls.Supported("preadv2", Preadv2), - 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), - 332: syscalls.Supported("statx", Statx), - }, - - Emulate: map[usermem.Addr]uintptr{ - 0xffffffffff600000: 96, // vsyscall gettimeofday(2) - 0xffffffffff600400: 201, // vsyscall time(2) - 0xffffffffff600800: 309, // vsyscall getcpu(2) - }, - Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { - t.Kernel().EmitUnimplementedEvent(t) - return 0, syserror.ENOSYS - }, -} + // LinuxVersion is the version info advertised by gVisor. + LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016" +) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go new file mode 100644 index 000000000..272ae9991 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -0,0 +1,406 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall +// numbers from Linux 4.4. +var AMD64 = &kernel.SyscallTable{ + OS: abi.Linux, + Arch: arch.AMD64, + Version: kernel.Version{ + // Version 4.4 is chosen as a stable, longterm version of Linux, which + // guides the interface provided by this syscall table. The build + // version is that for a clean build with default kernel config, at 5 + // minutes after v4.4 was tagged. + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, + }, + AuditNumber: linux.AUDIT_ARCH_X86_64, + Table: map[uintptr]kernel.Syscall{ + 0: syscalls.Supported("read", Read), + 1: syscalls.Supported("write", Write), + 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil), + 3: syscalls.Supported("close", Close), + 4: syscalls.Supported("stat", Stat), + 5: syscalls.Supported("fstat", Fstat), + 6: syscalls.Supported("lstat", Lstat), + 7: syscalls.Supported("poll", Poll), + 8: syscalls.Supported("lseek", Lseek), + 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil), + 10: syscalls.Supported("mprotect", Mprotect), + 11: syscalls.Supported("munmap", Munmap), + 12: syscalls.Supported("brk", Brk), + 13: syscalls.Supported("rt_sigaction", RtSigaction), + 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask), + 15: syscalls.Supported("rt_sigreturn", RtSigreturn), + 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), + 17: syscalls.Supported("pread64", Pread64), + 18: syscalls.Supported("pwrite64", Pwrite64), + 19: syscalls.Supported("readv", Readv), + 20: syscalls.Supported("writev", Writev), + 21: syscalls.Supported("access", Access), + 22: syscalls.Supported("pipe", Pipe), + 23: syscalls.Supported("select", Select), + 24: syscalls.Supported("sched_yield", SchedYield), + 25: syscalls.Supported("mremap", Mremap), + 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), + 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), + 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), + 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), + 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), + 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), + 32: syscalls.Supported("dup", Dup), + 33: syscalls.Supported("dup2", Dup2), + 34: syscalls.Supported("pause", Pause), + 35: syscalls.Supported("nanosleep", Nanosleep), + 36: syscalls.Supported("getitimer", Getitimer), + 37: syscalls.Supported("alarm", Alarm), + 38: syscalls.Supported("setitimer", Setitimer), + 39: syscalls.Supported("getpid", Getpid), + 40: syscalls.Supported("sendfile", Sendfile), + 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), + 42: syscalls.Supported("connect", Connect), + 43: syscalls.Supported("accept", Accept), + 44: syscalls.Supported("sendto", SendTo), + 45: syscalls.Supported("recvfrom", RecvFrom), + 46: syscalls.Supported("sendmsg", SendMsg), + 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), + 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), + 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), + 50: syscalls.Supported("listen", Listen), + 51: syscalls.Supported("getsockname", GetSockName), + 52: syscalls.Supported("getpeername", GetPeerName), + 53: syscalls.Supported("socketpair", SocketPair), + 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), + 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), + 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), + 57: syscalls.Supported("fork", Fork), + 58: syscalls.Supported("vfork", Vfork), + 59: syscalls.Supported("execve", Execve), + 60: syscalls.Supported("exit", Exit), + 61: syscalls.Supported("wait4", Wait4), + 62: syscalls.Supported("kill", Kill), + 63: syscalls.Supported("uname", Uname), + 64: syscalls.Supported("semget", Semget), + 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 67: syscalls.Supported("shmdt", Shmdt), + 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), + 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), + 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), + 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), + 76: syscalls.Supported("truncate", Truncate), + 77: syscalls.Supported("ftruncate", Ftruncate), + 78: syscalls.Supported("getdents", Getdents), + 79: syscalls.Supported("getcwd", Getcwd), + 80: syscalls.Supported("chdir", Chdir), + 81: syscalls.Supported("fchdir", Fchdir), + 82: syscalls.Supported("rename", Rename), + 83: syscalls.Supported("mkdir", Mkdir), + 84: syscalls.Supported("rmdir", Rmdir), + 85: syscalls.Supported("creat", Creat), + 86: syscalls.Supported("link", Link), + 87: syscalls.Supported("unlink", Unlink), + 88: syscalls.Supported("symlink", Symlink), + 89: syscalls.Supported("readlink", Readlink), + 90: syscalls.Supported("chmod", Chmod), + 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), + 92: syscalls.Supported("chown", Chown), + 93: syscalls.Supported("fchown", Fchown), + 94: syscalls.Supported("lchown", Lchown), + 95: syscalls.Supported("umask", Umask), + 96: syscalls.Supported("gettimeofday", Gettimeofday), + 97: syscalls.Supported("getrlimit", Getrlimit), + 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), + 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), + 100: syscalls.Supported("times", Times), + 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), + 102: syscalls.Supported("getuid", Getuid), + 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), + 104: syscalls.Supported("getgid", Getgid), + 105: syscalls.Supported("setuid", Setuid), + 106: syscalls.Supported("setgid", Setgid), + 107: syscalls.Supported("geteuid", Geteuid), + 108: syscalls.Supported("getegid", Getegid), + 109: syscalls.Supported("setpgid", Setpgid), + 110: syscalls.Supported("getppid", Getppid), + 111: syscalls.Supported("getpgrp", Getpgrp), + 112: syscalls.Supported("setsid", Setsid), + 113: syscalls.Supported("setreuid", Setreuid), + 114: syscalls.Supported("setregid", Setregid), + 115: syscalls.Supported("getgroups", Getgroups), + 116: syscalls.Supported("setgroups", Setgroups), + 117: syscalls.Supported("setresuid", Setresuid), + 118: syscalls.Supported("getresuid", Getresuid), + 119: syscalls.Supported("setresgid", Setresgid), + 120: syscalls.Supported("getresgid", Getresgid), + 121: syscalls.Supported("getpgid", Getpgid), + 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 124: syscalls.Supported("getsid", Getsid), + 125: syscalls.Supported("capget", Capget), + 126: syscalls.Supported("capset", Capset), + 127: syscalls.Supported("rt_sigpending", RtSigpending), + 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), + 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), + 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend), + 131: syscalls.Supported("sigaltstack", Sigaltstack), + 132: syscalls.Supported("utime", Utime), + 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), + 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), + 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), + 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), + 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), + 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), + 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), + 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), + 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), + 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), + 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), + 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), + 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), + 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), + 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), + 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), + 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil), + 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil), + 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil), + 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), + 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil), + 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), + 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), + 161: syscalls.Supported("chroot", Chroot), + 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), + 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), + 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), + 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), + 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), + 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), + 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), + 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), + 170: syscalls.Supported("sethostname", Sethostname), + 171: syscalls.Supported("setdomainname", Setdomainname), + 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil), + 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil), + 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), + 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), + 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), + 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations + 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), + 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), + 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), + 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), + 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), + 186: syscalls.Supported("gettid", Gettid), + 187: syscalls.Supported("readahead", Readahead), + 188: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil), + 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 191: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil), + 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 200: syscalls.Supported("tkill", Tkill), + 201: syscalls.Supported("time", Time), + 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), + 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), + 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), + 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), + 213: syscalls.Supported("epoll_create", EpollCreate), + 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), + 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), + 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 217: syscalls.Supported("getdents64", Getdents64), + 218: syscalls.Supported("set_tid_address", SetTidAddress), + 219: syscalls.Supported("restart_syscall", RestartSyscall), + 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), + 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), + 222: syscalls.Supported("timer_create", TimerCreate), + 223: syscalls.Supported("timer_settime", TimerSettime), + 224: syscalls.Supported("timer_gettime", TimerGettime), + 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun), + 226: syscalls.Supported("timer_delete", TimerDelete), + 227: syscalls.Supported("clock_settime", ClockSettime), + 228: syscalls.Supported("clock_gettime", ClockGettime), + 229: syscalls.Supported("clock_getres", ClockGetres), + 230: syscalls.Supported("clock_nanosleep", ClockNanosleep), + 231: syscalls.Supported("exit_group", ExitGroup), + 232: syscalls.Supported("epoll_wait", EpollWait), + 233: syscalls.Supported("epoll_ctl", EpollCtl), + 234: syscalls.Supported("tgkill", Tgkill), + 235: syscalls.Supported("utimes", Utimes), + 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), + 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), + 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), + 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), + 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), + 247: syscalls.Supported("waitid", Waitid), + 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), + 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), + 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), + 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), + 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), + 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), + 257: syscalls.Supported("openat", Openat), + 258: syscalls.Supported("mkdirat", Mkdirat), + 259: syscalls.Supported("mknodat", Mknodat), + 260: syscalls.Supported("fchownat", Fchownat), + 261: syscalls.Supported("futimesat", Futimesat), + 262: syscalls.Supported("fstatat", Fstatat), + 263: syscalls.Supported("unlinkat", Unlinkat), + 264: syscalls.Supported("renameat", Renameat), + 265: syscalls.Supported("linkat", Linkat), + 266: syscalls.Supported("symlinkat", Symlinkat), + 267: syscalls.Supported("readlinkat", Readlinkat), + 268: syscalls.Supported("fchmodat", Fchmodat), + 269: syscalls.Supported("faccessat", Faccessat), + 270: syscalls.Supported("pselect", Pselect), + 271: syscalls.Supported("ppoll", Ppoll), + 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), + 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), + 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), + 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) + 280: syscalls.Supported("utimensat", Utimensat), + 281: syscalls.Supported("epoll_pwait", EpollPwait), + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 283: syscalls.Supported("timerfd_create", TimerfdCreate), + 284: syscalls.Supported("eventfd", Eventfd), + 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), + 286: syscalls.Supported("timerfd_settime", TimerfdSettime), + 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), + 288: syscalls.Supported("accept4", Accept4), + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 290: syscalls.Supported("eventfd2", Eventfd2), + 291: syscalls.Supported("epoll_create1", EpollCreate1), + 292: syscalls.Supported("dup3", Dup3), + 293: syscalls.Supported("pipe2", Pipe2), + 294: syscalls.Supported("inotify_init1", InotifyInit1), + 295: syscalls.Supported("preadv", Preadv), + 296: syscalls.Supported("pwritev", Pwritev), + 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), + 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), + 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), + 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 302: syscalls.Supported("prlimit64", Prlimit64), + 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), + 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), + 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), + 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) + 309: syscalls.Supported("getcpu", Getcpu), + 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), + 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), + 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 317: syscalls.Supported("seccomp", Seccomp), + 318: syscalls.Supported("getrandom", GetRandom), + 319: syscalls.Supported("memfd_create", MemfdCreate), + 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), + 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), + 322: syscalls.Supported("execveat", Execveat), + 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + + // Syscalls implemented after 325 are "backports" from versions + // of Linux after 4.4. + 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 327: syscalls.Supported("preadv2", Preadv2), + 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), + 332: syscalls.Supported("statx", Statx), + 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + }, + + Emulate: map[usermem.Addr]uintptr{ + 0xffffffffff600000: 96, // vsyscall gettimeofday(2) + 0xffffffffff600400: 201, // vsyscall time(2) + 0xffffffffff600800: 309, // vsyscall getcpu(2) + }, + Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { + t.Kernel().EmitUnimplementedEvent(t) + return 0, syserror.ENOSYS + }, +} diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go new file mode 100644 index 000000000..3b584eed9 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -0,0 +1,332 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// ARM64 is a table of Linux arm64 syscall API with the corresponding syscall +// numbers from Linux 4.4. +var ARM64 = &kernel.SyscallTable{ + OS: abi.Linux, + Arch: arch.ARM64, + Version: kernel.Version{ + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, + }, + AuditNumber: linux.AUDIT_ARCH_AARCH64, + Table: map[uintptr]kernel.Syscall{ + 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 5: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil), + 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 8: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil), + 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 12: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 13: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 14: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 15: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 16: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 17: syscalls.Supported("getcwd", Getcwd), + 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), + 19: syscalls.Supported("eventfd2", Eventfd2), + 20: syscalls.Supported("epoll_create1", EpollCreate1), + 21: syscalls.Supported("epoll_ctl", EpollCtl), + 22: syscalls.Supported("epoll_pwait", EpollPwait), + 23: syscalls.Supported("dup", Dup), + 24: syscalls.Supported("dup3", Dup3), + 26: syscalls.Supported("inotify_init1", InotifyInit1), + 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), + 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), + 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), + 33: syscalls.Supported("mknodat", Mknodat), + 34: syscalls.Supported("mkdirat", Mkdirat), + 35: syscalls.Supported("unlinkat", Unlinkat), + 36: syscalls.Supported("symlinkat", Symlinkat), + 37: syscalls.Supported("linkat", Linkat), + 38: syscalls.Supported("renameat", Renameat), + 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), + 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), + 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil), + 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), + 46: syscalls.Supported("ftruncate", Ftruncate), + 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), + 48: syscalls.Supported("faccessat", Faccessat), + 49: syscalls.Supported("chdir", Chdir), + 50: syscalls.Supported("fchdir", Fchdir), + 51: syscalls.Supported("chroot", Chroot), + 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), + 53: syscalls.Supported("fchmodat", Fchmodat), + 54: syscalls.Supported("fchownat", Fchownat), + 55: syscalls.Supported("fchown", Fchown), + 56: syscalls.Supported("openat", Openat), + 57: syscalls.Supported("close", Close), + 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), + 59: syscalls.Supported("pipe2", Pipe2), + 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations + 61: syscalls.Supported("getdents64", Getdents64), + 62: syscalls.Supported("lseek", Lseek), + 63: syscalls.Supported("read", Read), + 64: syscalls.Supported("write", Write), + 65: syscalls.Supported("readv", Readv), + 66: syscalls.Supported("writev", Writev), + 67: syscalls.Supported("pread64", Pread64), + 68: syscalls.Supported("pwrite64", Pwrite64), + 69: syscalls.Supported("preadv", Preadv), + 70: syscalls.Supported("pwritev", Pwritev), + 71: syscalls.Supported("sendfile", Sendfile), + 72: syscalls.Supported("pselect", Pselect), + 73: syscalls.Supported("ppoll", Ppoll), + 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 77: syscalls.Supported("tee", Tee), + 78: syscalls.Supported("readlinkat", Readlinkat), + 80: syscalls.Supported("fstat", Fstat), + 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), + 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), + 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), + 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), + 85: syscalls.Supported("timerfd_create", TimerfdCreate), + 86: syscalls.Supported("timerfd_settime", TimerfdSettime), + 87: syscalls.Supported("timerfd_gettime", TimerfdGettime), + 88: syscalls.Supported("utimensat", Utimensat), + 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), + 90: syscalls.Supported("capget", Capget), + 91: syscalls.Supported("capset", Capset), + 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 93: syscalls.Supported("exit", Exit), + 94: syscalls.Supported("exit_group", ExitGroup), + 95: syscalls.Supported("waitid", Waitid), + 96: syscalls.Supported("set_tid_address", SetTidAddress), + 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), + 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), + 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 101: syscalls.Supported("nanosleep", Nanosleep), + 102: syscalls.Supported("getitimer", Getitimer), + 103: syscalls.Supported("setitimer", Setitimer), + 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), + 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), + 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), + 107: syscalls.Supported("timer_create", TimerCreate), + 108: syscalls.Supported("timer_gettime", TimerGettime), + 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun), + 110: syscalls.Supported("timer_settime", TimerSettime), + 111: syscalls.Supported("timer_delete", TimerDelete), + 112: syscalls.Supported("clock_settime", ClockSettime), + 113: syscalls.Supported("clock_gettime", ClockGettime), + 114: syscalls.Supported("clock_getres", ClockGetres), + 115: syscalls.Supported("clock_nanosleep", ClockNanosleep), + 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), + 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), + 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), + 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), + 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), + 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), + 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), + 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), + 124: syscalls.Supported("sched_yield", SchedYield), + 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), + 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), + 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), + 128: syscalls.Supported("restart_syscall", RestartSyscall), + 129: syscalls.Supported("kill", Kill), + 130: syscalls.Supported("tkill", Tkill), + 131: syscalls.Supported("tgkill", Tgkill), + 132: syscalls.Supported("sigaltstack", Sigaltstack), + 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend), + 134: syscalls.Supported("rt_sigaction", RtSigaction), + 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask), + 136: syscalls.Supported("rt_sigpending", RtSigpending), + 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), + 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), + 139: syscalls.Supported("rt_sigreturn", RtSigreturn), + 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), + 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), + 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), + 143: syscalls.Supported("setregid", Setregid), + 144: syscalls.Supported("setgid", Setgid), + 145: syscalls.Supported("setreuid", Setreuid), + 146: syscalls.Supported("setuid", Setuid), + 147: syscalls.Supported("setresuid", Setresuid), + 148: syscalls.Supported("getresuid", Getresuid), + 149: syscalls.Supported("setresgid", Setresgid), + 150: syscalls.Supported("getresgid", Getresgid), + 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 153: syscalls.Supported("times", Times), + 154: syscalls.Supported("setpgid", Setpgid), + 155: syscalls.Supported("getpgid", Getpgid), + 156: syscalls.Supported("getsid", Getsid), + 157: syscalls.Supported("setsid", Setsid), + 158: syscalls.Supported("getgroups", Getgroups), + 159: syscalls.Supported("setgroups", Setgroups), + 160: syscalls.Supported("uname", Uname), + 161: syscalls.Supported("sethostname", Sethostname), + 162: syscalls.Supported("setdomainname", Setdomainname), + 163: syscalls.Supported("getrlimit", Getrlimit), + 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), + 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), + 166: syscalls.Supported("umask", Umask), + 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), + 168: syscalls.Supported("getcpu", Getcpu), + 169: syscalls.Supported("gettimeofday", Gettimeofday), + 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), + 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), + 172: syscalls.Supported("getpid", Getpid), + 173: syscalls.Supported("getppid", Getppid), + 174: syscalls.Supported("getuid", Getuid), + 175: syscalls.Supported("geteuid", Geteuid), + 176: syscalls.Supported("getgid", Getgid), + 177: syscalls.Supported("getegid", Getegid), + 178: syscalls.Supported("gettid", Gettid), + 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), + 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 190: syscalls.Supported("semget", Semget), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), + 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), + 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), + 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), + 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), + 197: syscalls.Supported("shmdt", Shmdt), + 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), + 199: syscalls.Supported("socketpair", SocketPair), + 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), + 201: syscalls.Supported("listen", Listen), + 202: syscalls.Supported("accept", Accept), + 203: syscalls.Supported("connect", Connect), + 204: syscalls.Supported("getsockname", GetSockName), + 205: syscalls.Supported("getpeername", GetPeerName), + 206: syscalls.Supported("sendto", SendTo), + 207: syscalls.Supported("recvfrom", RecvFrom), + 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), + 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), + 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), + 211: syscalls.Supported("sendmsg", SendMsg), + 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), + 213: syscalls.Supported("readahead", Readahead), + 214: syscalls.Supported("brk", Brk), + 215: syscalls.Supported("munmap", Munmap), + 216: syscalls.Supported("mremap", Mremap), + 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), + 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), + 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), + 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), + 221: syscalls.Supported("execve", Execve), + 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), + 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), + 226: syscalls.Supported("mprotect", Mprotect), + 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), + 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), + 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), + 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), + 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), + 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), + 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), + 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) + 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), + 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), + 242: syscalls.Supported("accept4", Accept4), + 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), + 260: syscalls.Supported("wait4", Wait4), + 261: syscalls.Supported("prlimit64", Prlimit64), + 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), + 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), + 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) + 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), + 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), + 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), + 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 277: syscalls.Supported("seccomp", Seccomp), + 278: syscalls.Supported("getrandom", GetRandom), + 279: syscalls.Supported("memfd_create", MemfdCreate), + 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), + 281: syscalls.Supported("execveat", Execveat), + 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 286: syscalls.Supported("preadv2", Preadv2), + 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), + 291: syscalls.Supported("statx", Statx), + 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + }, + Emulate: map[usermem.Addr]uintptr{}, + + Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { + t.Kernel().EmitUnimplementedEvent(t) + return 0, syserror.ENOSYS + }, +} diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 2e00a91ce..9bc2445a5 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -169,10 +169,14 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint if dirPath { return syserror.ENOTDIR } - if flags&linux.O_TRUNC != 0 { - if err := d.Inode.Truncate(t, d, 0); err != nil { - return err - } + } + + // Truncate is called when O_TRUNC is specified for any kind of + // existing Dirent. Behavior is delegated to the entry's Truncate + // implementation. + if flags&linux.O_TRUNC != 0 { + if err := d.Inode.Truncate(t, d, 0); err != nil { + return err } } @@ -396,7 +400,9 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l return err } - // Should we truncate the file? + // Truncate is called when O_TRUNC is specified for any kind of + // existing Dirent. Behavior is delegated to the entry's Truncate + // implementation. if flags&linux.O_TRUNC != 0 { if err := found.Inode.Truncate(t, found, 0); err != nil { return err @@ -834,25 +840,42 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(newfd), nil, nil } -func fGetOwn(t *kernel.Task, file *fs.File) int32 { +func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx { ma := file.Async(nil) if ma == nil { - return 0 + return linux.FOwnerEx{} } a := ma.(*fasync.FileAsync) ot, otg, opg := a.Owner() switch { case ot != nil: - return int32(t.PIDNamespace().IDOfTask(ot)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_TID, + PID: int32(t.PIDNamespace().IDOfTask(ot)), + } case otg != nil: - return int32(t.PIDNamespace().IDOfThreadGroup(otg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PID, + PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)), + } case opg != nil: - return int32(-t.PIDNamespace().IDOfProcessGroup(opg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PGRP, + PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)), + } default: - return 0 + return linux.FOwnerEx{} } } +func fGetOwn(t *kernel.Task, file *fs.File) int32 { + owner := fGetOwnEx(t, file) + if owner.Type == linux.F_OWNER_PGRP { + return -owner.PID + } + return owner.PID +} + // fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux. // // If who is positive, it represents a PID. If negative, it represents a PGID. @@ -895,11 +918,13 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall t.FDTable().SetFlags(fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) + return 0, nil, nil case linux.F_GETFL: return uintptr(file.Flags().ToLinux()), nil, nil case linux.F_SETFL: flags := uint(args[2].Uint()) file.SetFlags(linuxToFlags(flags).Settable()) + return 0, nil, nil case linux.F_SETLK, linux.F_SETLKW: // In Linux the file system can choose to provide lock operations for an inode. // Normally pipe and socket types lack lock operations. We diverge and use a heavy @@ -1002,6 +1027,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_SETOWN: fSetOwn(t, file, args[2].Int()) return 0, nil, nil + case linux.F_GETOWN_EX: + addr := args[2].Pointer() + owner := fGetOwnEx(t, file) + _, err := t.CopyOut(addr, &owner) + return 0, nil, err + case linux.F_SETOWN_EX: + addr := args[2].Pointer() + var owner linux.FOwnerEx + n, err := t.CopyIn(addr, &owner) + if err != nil { + return 0, nil, err + } + a := file.Async(fasync.New).(*fasync.FileAsync) + switch owner.Type { + case linux.F_OWNER_TID: + task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) + if task == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerTask(t, task) + return uintptr(n), nil, nil + case linux.F_OWNER_PID: + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID)) + if tg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerThreadGroup(t, tg) + return uintptr(n), nil, nil + case linux.F_OWNER_PGRP: + pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID)) + if pg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerProcessGroup(t, pg) + return uintptr(n), nil, nil + default: + return 0, nil, syserror.EINVAL + } case linux.F_GET_SEALS: val, err := tmpfs.GetSeals(file.Dirent.Inode) return uintptr(val), nil, err @@ -1029,7 +1092,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Everything else is not yet supported. return 0, nil, syserror.EINVAL } - return 0, nil, nil } const ( @@ -1423,9 +1485,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { if err != nil { return err } - if dirPath { - return syserror.ENOENT - } return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error { if !fs.IsDir(d.Inode.StableAttr) { @@ -1436,7 +1495,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return err } - return d.Remove(t, root, name) + return d.Remove(t, root, name, dirPath) }) } @@ -1486,6 +1545,8 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if fs.IsDir(d.Inode.StableAttr) { return syserror.EISDIR } + // In contrast to open(O_TRUNC), truncate(2) is only valid for file + // types. if !fs.IsFile(d.Inode.StableAttr) { return syserror.EINVAL } @@ -1524,7 +1585,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EINVAL } - // Note that this is different from truncate(2) above, where a + // In contrast to open(O_TRUNC), truncate(2) is only valid for file + // types. Note that this is different from truncate(2) above, where a // directory returns EISDIR. if !fs.IsFile(file.Dirent.Inode.StableAttr) { return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 7a13beac2..2b2df989a 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) return remainingTimeout, n, err } +// CopyInFDSet copies an fd set from select(2)/pselect(2). +func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { + set := make([]byte, nBytes) + + if addr != 0 { + if _, err := t.CopyIn(addr, &set); err != nil { + return nil, err + } + // If we only use part of the last byte, mask out the extraneous bits. + // + // N.B. This only works on little-endian architectures. + if nBitsInLastPartialByte != 0 { + set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte + } + } + return set, nil +} + func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { return 0, syserror.EINVAL } - // Capture all the provided input vectors. - // - // N.B. This only works on little-endian architectures. - byteCount := (nfds + 7) / 8 - - bitsInLastPartialByte := uint(nfds % 8) - r := make([]byte, byteCount) - w := make([]byte, byteCount) - e := make([]byte, byteCount) + // Calculate the size of the fd sets (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 - if readFDs != 0 { - if _, err := t.CopyIn(readFDs, &r); err != nil { - return 0, err - } - // Mask out bits above nfds. - if bitsInLastPartialByte != 0 { - r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + // Capture all the provided input vectors. + r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if writeFDs != 0 { - if _, err := t.CopyIn(writeFDs, &w); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if exceptFDs != 0 { - if _, err := t.CopyIn(exceptFDs, &e); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } // Count how many FDs are actually being requested so that we can build // a PollFD array. fdCount := 0 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { v := r[i] | w[i] | e[i] for v != 0 { v &= (v - 1) @@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Build the PollFD array. pfd := make([]linux.PollFD, 0, fdCount) var fd int32 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { rV, wV, eV := r[i], w[i], e[i] v := rV | wV | eV m := byte(1) @@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add } // Do the syscall, then count the number of bits set. - _, _, err := pollBlock(t, pfd, timeout) - if err != nil { + if _, _, err = pollBlock(t, pfd, timeout); err != nil { return 0, syserror.ConvertIntr(err, syserror.EINTR) } diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 3ab54271c..cd31e0649 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file) } +// Readahead implements readahead(2). +func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + size := args[2].SizeT() + + file := t.GetFile(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the file is readable. + if !file.Flags().Read { + return 0, nil, syserror.EBADF + } + + // Check that the size is valid. + if int(size) < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the offset is legitimate. + if offset < 0 { + return 0, nil, syserror.EINVAL + } + + // Return EINVAL; if the underlying file type does not support readahead, + // then Linux will return EINVAL to indicate as much. In the future, we + // may extend this function to actually support readahead hints. + return 0, nil, syserror.EINVAL +} + // Pread64 implements linux syscall pread64(2). func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 0104a94c0..fb6efd5d8 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -20,7 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" + "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?") return 0, nil, syserror.EINTR } + +// sharedSignalfd is shared between the two calls. +func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { + // Copy in the signal mask. + mask, err := copyInSigSet(t, sigset, sigsetsize) + if err != nil { + return 0, nil, err + } + + // Always check for valid flags, even if not creating. + if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 { + return 0, nil, syserror.EINVAL + } + + // Is this a change to an existing signalfd? + // + // The spec indicates that this should adjust the mask. + if fd != -1 { + file := t.GetFile(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Is this a signalfd? + if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok { + s.SetMask(mask) + return 0, nil, nil + } + + // Not a signalfd. + return 0, nil, syserror.EINVAL + } + + // Create a new file. + file, err := signalfd.New(t, mask) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + // Set appropriate flags. + file.SetFlags(fs.SettableFileFlags{ + NonBlocking: flags&linux.SFD_NONBLOCK != 0, + }) + + // Create a new descriptor. + fd, err = t.NewFDFrom(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.SFD_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + + // Done. + return uintptr(fd), nil, nil +} + +// Signalfd implements the linux syscall signalfd(2). +func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + return sharedSignalfd(t, fd, sigset, sigsetsize, 0) +} + +// Signalfd4 implements the linux syscall signalfd4(2). +func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + flags := args[3].Int() + return sharedSignalfd(t, fd, sigset, sigsetsize, flags) +} diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 3bac4d90d..4b5aafcc0 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -447,16 +447,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - // Read the length if present. Reject negative values. + // Read the length. Reject negative values. optLen := int32(0) - if optLenAddr != 0 { - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { - return 0, nil, err - } - - if optLen < 0 { - return 0, nil, syserror.EINVAL - } + if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + return 0, nil, err + } + if optLen < 0 { + return 0, nil, syserror.EINVAL } // Call syscall implementation then copy both value and value len out. @@ -465,11 +462,9 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - if optLenAddr != 0 { - vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { - return 0, nil, err - } + vLen := int32(binary.Size(v)) + if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + return 0, nil, err } if v != nil { @@ -531,7 +526,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - if optLen <= 0 { + if optLen < 0 { return 0, nil, syserror.EINVAL } if optLen > maxOptLen { @@ -769,7 +764,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Unix.Release() + cms.Release() } if int(msg.Flags) != mflags { @@ -789,24 +784,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } - defer cms.Unix.Release() + defer cms.Release() controlData := make([]byte, 0, msg.ControlLen) + controlData = control.PackControlMessages(t, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) controlData, mflags = control.PackCredentials(t, creds, controlData, mflags) } - if cms.IP.HasTimestamp { - controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData) - } - - if cms.IP.HasInq { - // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. - controlData = control.PackInq(t, cms.IP.Inq, controlData) - } - if cms.Unix.Rights != nil { controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) } @@ -882,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Unix.Release() + cm.Release() if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -1065,7 +1052,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme } // Call the syscall implementation. - n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages}) + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { controlMessages.Release() diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 8a98fedcb..dd3a5807f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB total int64 n int64 err error - ch chan struct{} - inW bool - outW bool + inCh chan struct{} + outCh chan struct{} ) for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) @@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB break } - // Are we a registered waiter? - if ch == nil { - ch = make(chan struct{}, 1) - } - if !inW && !inFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - inFile.EventRegister(&w, EventMaskRead) - defer inFile.EventUnregister(&w) - inW = true // Registered. - } else if !outW && !outFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - outFile.EventRegister(&w, EventMaskWrite) - defer outFile.EventUnregister(&w) - outW = true // Registered. - } - - // Was anything registered? If no, everything is non-blocking. - if !inW && !outW { - break - } - - if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) { - // Something became ready, try again without blocking. - continue + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(EventMaskRead) == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, EventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } } - - // Block until there's data. - if err = t.Block(ch); err != nil { - break + if outFile.Readiness(EventMaskWrite) == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } } } @@ -149,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Length: count, SrcOffset: true, SrcStart: offset, - }, false) + }, outFile.Flags().NonBlocking) // Copy out the new offset. if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { @@ -159,12 +156,17 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Send data using splice. n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, - }, false) + }, outFile.Flags().NonBlocking) + } + + // Sendfile can't lose any data because inFD is always a regual file. + if n != 0 { + err = nil } // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -181,12 +183,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. Note that unlike in Linux, this - // flag is applied consistently. We will have either fully blocking or - // non-blocking behavior below, regardless of the underlying files - // being spliced to. It's unclear if this is a bug or not yet. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -200,6 +196,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer inFile.DecRef() + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Construct our options. // // Note that exactly one of the underlying buffers must be a pipe. We @@ -247,17 +250,17 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inOffset != 0 || outOffset != 0 { return 0, nil, syserror.ESPIPE } - default: - return 0, nil, syserror.EINVAL - } - // We may not refer to the same pipe; otherwise it's a continuous loop. - if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { + // We may not refer to the same pipe; otherwise it's a continuous loop. + if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { + return 0, nil, syserror.EINVAL + } + default: return 0, nil, syserror.EINVAL } // Splice data. - n, err := doSplice(t, outFile, inFile, opts, nonBlocking) + n, err := doSplice(t, outFile, inFile, opts, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) @@ -275,9 +278,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -301,12 +301,20 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } + // The operation is non-blocking if anything is non-blocking. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Splice data. n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, Dup: true, - }, nonBlocking) + }, nonBlock) + + // Tee doesn't change a state of inFD, so it can't lose any data. + if n != 0 { + err = nil + } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 8ab7ffa25..4115116ff 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -15,12 +15,15 @@ package linux import ( + "path" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" + "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -67,8 +70,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal argvAddr := args[1].Pointer() envvAddr := args[2].Pointer() - // Extract our arguments. - filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX) + return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0) +} + +// Execveat implements linux syscall execveat(2). +func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirFD := args[0].Int() + pathnameAddr := args[1].Pointer() + argvAddr := args[2].Pointer() + envvAddr := args[3].Pointer() + flags := args[4].Int() + + return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags) +} + +func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { + pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) if err != nil { return 0, nil, err } @@ -89,14 +106,66 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } } + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 + if !atEmptyPath && len(pathname) == 0 { + return 0, nil, syserror.ENOENT + } + resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 + root := t.FSContext().RootDirectory() defer root.DecRef() - wd := t.FSContext().WorkingDirectory() - defer wd.DecRef() + + var wd *fs.Dirent + var executable *fs.File + var closeOnExec bool + if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) { + // Even if the pathname is absolute, we may still need the wd + // for interpreter scripts if the path of the interpreter is + // relative. + wd = t.FSContext().WorkingDirectory() + } else { + // Need to extract the given FD. + f, fdFlags := t.FDTable().Get(dirFD) + if f == nil { + return 0, nil, syserror.EBADF + } + defer f.DecRef() + closeOnExec = fdFlags.CloseOnExec + + if atEmptyPath && len(pathname) == 0 { + executable = f + } else { + wd = f.Dirent + wd.IncRef() + if !fs.IsDir(wd.Inode.StableAttr) { + return 0, nil, syserror.ENOTDIR + } + } + } + if wd != nil { + defer wd.DecRef() + } // Load the new TaskContext. - maxTraversals := uint(linux.MaxSymlinkTraversals) - tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet()) + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: t.MountNamespace(), + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: resolveFinal, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 4b3f043a2..b887fa9d7 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -15,6 +15,7 @@ package linux import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use timer.Destroy() - var remaining time.Duration - // Did we just block for the entire duration? - if err == syserror.ETIMEDOUT { - remaining = 0 - } else { - remaining = dur - after.Sub(start) + switch err { + case syserror.ETIMEDOUT: + // Slept for entire timeout. + return nil + case syserror.ErrInterrupted: + // Interrupted. + remaining := dur - after.Sub(start) if remaining < 0 { remaining = time.Duration(0) } - } - // Copy out remaining time. - if err != nil && rem != usermem.Addr(0) { - timeleft := linux.NsecToTimespec(remaining.Nanoseconds()) - if err := copyTimespecOut(t, rem, &timeleft); err != nil { - return err + // Copy out remaining time. + if rem != 0 { + timeleft := linux.NsecToTimespec(remaining.Nanoseconds()) + if err := copyTimespecOut(t, rem, &timeleft); err != nil { + return err + } } - } - - // Did we just block for the entire duration? - if err == syserror.ETIMEDOUT { - return nil - } - // If interrupted, arrange for a restart with the remaining duration. - if err == syserror.ErrInterrupted { + // Arrange for a restart with the remaining duration. t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{ c: c, duration: remaining, rem: rem, }) return kernel.ERESTART_RESTARTBLOCK + default: + panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } - - return err } // Nanosleep implements linux syscall Nanosleep(2). diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index 271ace08e..748e8dd8d 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - name, err := t.CopyInString(nameAddr, int(size)) - if err != nil { + name := make([]byte, size) + if _, err := t.CopyInBytes(nameAddr, name); err != nil { return 0, nil, err } - utsns.SetHostName(name) + utsns.SetHostName(string(name)) return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 27cd2c336..ad4b67806 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Pwritev2 implements linux syscall pwritev2(2). -// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality. func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { // While the syscall is // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go new file mode 100644 index 000000000..97d9a65ea --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -0,0 +1,169 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Getxattr implements linux syscall getxattr(2). +func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */) + if err != nil { + return 0, nil, err + } + + valueLen := 0 + err = fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + value, err := getxattr(t, d, dirPath, nameAddr) + if err != nil { + return err + } + + valueLen = len(value) + if size == 0 { + return nil + } + if size > linux.XATTR_SIZE_MAX { + size = linux.XATTR_SIZE_MAX + } + if valueLen > int(size) { + return syserror.ERANGE + } + + _, err = t.CopyOutBytes(valueAddr, []byte(value)) + return err + }) + if err != nil { + return 0, nil, err + } + return uintptr(valueLen), nil, nil +} + +// getxattr implements getxattr from the given *fs.Dirent. +func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr) (string, error) { + if dirPath && !fs.IsDir(d.Inode.StableAttr) { + return "", syserror.ENOTDIR + } + + if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil { + return "", err + } + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return "", err + } + + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + + return d.Inode.Getxattr(name) +} + +// Setxattr implements linux syscall setxattr(2). +func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Uint() + + path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */) + if err != nil { + return 0, nil, err + } + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return 0, nil, syserror.EINVAL + } + + return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + return setxattr(t, d, dirPath, nameAddr, valueAddr, size, flags) + }) +} + +// setxattr implements setxattr from the given *fs.Dirent. +func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint, flags uint32) error { + if dirPath && !fs.IsDir(d.Inode.StableAttr) { + return syserror.ENOTDIR + } + + if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil { + return err + } + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + + if size > linux.XATTR_SIZE_MAX { + return syserror.E2BIG + } + buf := make([]byte, size) + if _, err = t.CopyInBytes(valueAddr, buf); err != nil { + return err + } + value := string(buf) + + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + + return d.Inode.Setxattr(name, value) +} + +func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { + name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) + if err != nil { + if err == syserror.ENAMETOOLONG { + return "", syserror.ERANGE + } + return "", err + } + if len(name) == 0 { + return "", syserror.ERANGE + } + return name, nil +} + +func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error { + // Restrict xattrs to regular files and directories. + // + // In Linux, this restriction technically only applies to xattrs in the + // "user.*" namespace, but we don't allow any other xattr prefixes anyway. + if !fs.IsRegular(i.StableAttr) && !fs.IsDir(i.StableAttr) { + if perms.Write { + return syserror.EPERM + } + return syserror.ENODATA + } + + return i.CheckPermission(t, perms) +} diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 8aa6a3017..18e212dff 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,15 +1,15 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "seqatomic_parameters", out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//third_party/gvsync:generic_seqatomic", + template = "//pkg/syncutil:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,8 +36,8 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", + "//pkg/syncutil", "//pkg/syserror", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index b69603da3..fc7614fff 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -10,6 +11,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "unimplemented_syscall_cc_proto", + visibility = ["//visibility:public"], + deps = [":unimplemented_syscall_proto"], +) + go_proto_library( name = "unimplemented_syscall_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index a34c39540..c32fe3241 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "usage", srcs = [ diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index a5b4206bb..684f59a6b 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "addr_range", diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go index 8d88396ba..7898851b3 100644 --- a/pkg/sentry/usermem/bytes_io.go +++ b/pkg/sentry/usermem/bytes_io.go @@ -102,19 +102,34 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { } func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { - blocks := make([]safemem.Block, 0, ars.NumRanges()) - for !ars.IsEmpty() { - ar := ars.Head() - n, err := b.rangeCheck(ar.Start, int(ar.Length())) - if n != 0 { - blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n])) + switch ars.NumRanges() { + case 0: + return safemem.BlockSeq{}, nil + case 1: + block, err := b.blockFromAddrRange(ars.Head()) + return safemem.BlockSeqOf(block), err + default: + blocks := make([]safemem.Block, 0, ars.NumRanges()) + for !ars.IsEmpty() { + block, err := b.blockFromAddrRange(ars.Head()) + if block.Len() != 0 { + blocks = append(blocks, block) + } + if err != nil { + return safemem.BlockSeqFromSlice(blocks), err + } + ars = ars.Tail() } - if err != nil { - return safemem.BlockSeqFromSlice(blocks), err - } - ars = ars.Tail() + return safemem.BlockSeqFromSlice(blocks), nil + } +} + +func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { + n, err := b.rangeCheck(ar.Start, int(ar.Length())) + if n == 0 { + return safemem.Block{}, err } - return safemem.BlockSeqFromSlice(blocks), nil + return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err } // BytesIOSequence returns an IOSequence representing the given byte slice. diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go index 6eced660a..7b1f312b1 100644 --- a/pkg/sentry/usermem/usermem.go +++ b/pkg/sentry/usermem/usermem.go @@ -16,6 +16,7 @@ package usermem import ( + "bytes" "errors" "io" "strconv" @@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts) // Look for the terminating zero byte, which may have occurred before // hitting err. - for i, c := range buf[done : done+n] { - if c == 0 { - return stringFromImmutableBytes(buf[:done+i]), nil - } + if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { + return stringFromImmutableBytes(buf[:done+i]), nil } + done += n if err != nil { return stringFromImmutableBytes(buf[:done]), err diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0f247bf77..e3e554b88 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -11,13 +12,14 @@ go_library( "file_description.go", "file_description_impl_util.go", "filesystem.go", + "filesystem_impl_util.go", "filesystem_type.go", "mount.go", "mount_unsafe.go", "options.go", + "pathname.go", "permissions.go", "resolving_path.go", - "syscalls.go", "testutil.go", "vfs.go", ], @@ -31,9 +33,9 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/syncutil", "//pkg/syserror", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 7847854bc..9aa133bcb 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -39,8 +39,8 @@ Mount references are held by: - Mount: Each referenced Mount holds a reference on its parent, which is the mount containing its mount point. -- VirtualFilesystem: A reference is held on all Mounts that are attached - (reachable by Mount traversal). +- VirtualFilesystem: A reference is held on each Mount that has not been + umounted. MountNamespace and FileDescription references are held by users of VFS. The expectation is that each `kernel.Task` holds a reference on its corresponding diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index 32cf9151b..705194ebc 100644 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go @@ -24,6 +24,9 @@ type contextID int const ( // CtxMountNamespace is a Context.Value key for a MountNamespace. CtxMountNamespace contextID = iota + + // CtxRoot is a Context.Value key for a VFS root. + CtxRoot ) // MountNamespaceFromContext returns the MountNamespace used by ctx. It does @@ -35,3 +38,13 @@ func MountNamespaceFromContext(ctx context.Context) *MountNamespace { } return nil } + +// RootFromContext returns the VFS root used by ctx. It takes a reference on +// the returned VirtualDentry. If ctx does not have a specific VFS root, +// RootFromContext returns a zero-value VirtualDentry. +func RootFromContext(ctx context.Context) VirtualDentry { + if v := ctx.Value(CtxRoot); v != nil { + return v.(VirtualDentry) + } + return VirtualDentry{} +} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 45912fc58..40f4c1d09 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,6 +16,7 @@ package vfs import ( "fmt" + "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/syserror" @@ -50,7 +51,7 @@ import ( // and not inodes. Furthermore, when parties outside the scope of VFS can // rename inodes on such filesystems, VFS generally cannot "follow" the rename, // both due to synchronization issues and because it may not even be able to -// name the destination path; this implies that it would in fact be *incorrect* +// name the destination path; this implies that it would in fact be incorrect // for Dentries to be associated with inodes on such filesystems. Consequently, // operations that are inode operations in Linux are FilesystemImpl methods // and/or FileDescriptionImpl methods in gVisor's VFS. Filesystems that do @@ -84,6 +85,9 @@ type Dentry struct { // mounts is accessed using atomic memory operations. mounts uint32 + // mu synchronizes disowning and mounting over this Dentry. + mu sync.Mutex + // children are child Dentries. children map[string]*Dentry @@ -114,7 +118,7 @@ func (d *Dentry) Impl() DentryImpl { type DentryImpl interface { // IncRef increments the Dentry's reference count. A Dentry with a non-zero // reference count must remain coherent with the state of the filesystem. - IncRef(fs *Filesystem) + IncRef() // TryIncRef increments the Dentry's reference count and returns true. If // the Dentry's reference count is zero, TryIncRef may do nothing and @@ -122,10 +126,10 @@ type DentryImpl interface { // guarantee that the Dentry is coherent with the state of the filesystem.) // // TryIncRef does not require that a reference is held on the Dentry. - TryIncRef(fs *Filesystem) bool + TryIncRef() bool // DecRef decrements the Dentry's reference count. - DecRef(fs *Filesystem) + DecRef() } // IsDisowned returns true if d is disowned. @@ -142,16 +146,20 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } -func (d *Dentry) incRef(fs *Filesystem) { - d.impl.IncRef(fs) +// IncRef increments d's reference count. +func (d *Dentry) IncRef() { + d.impl.IncRef() } -func (d *Dentry) tryIncRef(fs *Filesystem) bool { - return d.impl.TryIncRef(fs) +// TryIncRef increments d's reference count and returns true. If d's reference +// count is zero, TryIncRef may instead do nothing and return false. +func (d *Dentry) TryIncRef() bool { + return d.impl.TryIncRef() } -func (d *Dentry) decRef(fs *Filesystem) { - d.impl.DecRef(fs) +// DecRef decrements d's reference count. +func (d *Dentry) DecRef() { + d.impl.DecRef() } // These functions are exported so that filesystem implementations can use @@ -228,36 +236,48 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent panic("d is already disowned") } } - vfs.mountMu.RLock() - if _, ok := mntns.mountpoints[d]; ok { - vfs.mountMu.RUnlock() + vfs.mountMu.Lock() + if mntns.mountpoints[d] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } - // Return with vfs.mountMu locked, which will be unlocked by - // AbortDeleteDentry or CommitDeleteDentry. + d.mu.Lock() + vfs.mountMu.Unlock() + // Return with d.mu locked to block attempts to mount over it; it will be + // unlocked by AbortDeleteDentry or CommitDeleteDentry. return nil } // AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion // fails. -func (vfs *VirtualFilesystem) AbortDeleteDentry() { - vfs.mountMu.RUnlock() +func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) { + d.mu.Unlock() } // CommitDeleteDentry must be called after the file represented by d is // deleted, and causes d to become disowned. // +// CommitDeleteDentry is a mutator of d and d.Parent(). +// // Preconditions: PrepareDeleteDentry was previously called on d. func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { - delete(d.parent.children, d.name) + if d.parent != nil { + delete(d.parent.children, d.name) + } d.setDisowned() - // TODO: lazily unmount mounts at d - vfs.mountMu.RUnlock() + d.mu.Unlock() + if d.isMounted() { + vfs.forgetDisownedMountpoint(d) + } } // DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as // appropriate for in-memory filesystems that don't need to ensure that some // external state change succeeds before committing the deletion. +// +// DeleteDentry is a mutator of d and d.Parent(). +// +// Preconditions: d is a child Dentry. func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error { if err := vfs.PrepareDeleteDentry(mntns, d); err != nil { return err @@ -266,6 +286,27 @@ func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) err return nil } +// ForceDeleteDentry causes d to become disowned. It should only be used in +// cases where VFS has no ability to stop the deletion (e.g. d represents the +// local state of a file on a remote filesystem on which the file has already +// been deleted). +// +// ForceDeleteDentry is a mutator of d and d.Parent(). +// +// Preconditions: d is a child Dentry. +func (vfs *VirtualFilesystem) ForceDeleteDentry(d *Dentry) { + if checkInvariants { + if d.parent == nil { + panic("d is independent") + } + if d.IsDisowned() { + panic("d is already disowned") + } + } + d.mu.Lock() + vfs.CommitDeleteDentry(d) +} + // PrepareRenameDentry must be called before attempting to rename the file // represented by from. If to is not nil, it represents the file that will be // replaced or exchanged by the rename. If PrepareRenameDentry succeeds, the @@ -291,18 +332,21 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t } } } - vfs.mountMu.RLock() - if _, ok := mntns.mountpoints[from]; ok { - vfs.mountMu.RUnlock() + vfs.mountMu.Lock() + if mntns.mountpoints[from] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } if to != nil { - if _, ok := mntns.mountpoints[to]; ok { - vfs.mountMu.RUnlock() + if mntns.mountpoints[to] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } + to.mu.Lock() } - // Return with vfs.mountMu locked, which will be unlocked by + from.mu.Lock() + vfs.mountMu.Unlock() + // Return with from.mu and to.mu locked, which will be unlocked by // AbortRenameDentry, CommitRenameReplaceDentry, or // CommitRenameExchangeDentry. return nil @@ -310,38 +354,76 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t // AbortRenameDentry must be called after PrepareRenameDentry if the rename // fails. -func (vfs *VirtualFilesystem) AbortRenameDentry() { - vfs.mountMu.RUnlock() +func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { + from.mu.Unlock() + if to != nil { + to.mu.Unlock() + } } // CommitRenameReplaceDentry must be called after the file represented by from // is renamed without RENAME_EXCHANGE. If to is not nil, it represents the file // that was replaced by from. // +// CommitRenameReplaceDentry is a mutator of from, to, from.Parent(), and +// to.Parent(). +// // Preconditions: PrepareRenameDentry was previously called on from and to. // newParent.Child(newName) == to. func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, newParent *Dentry, newName string, to *Dentry) { - if to != nil { - to.setDisowned() - // TODO: lazily unmount mounts at d - } if newParent.children == nil { newParent.children = make(map[string]*Dentry) } newParent.children[newName] = from from.parent = newParent from.name = newName - vfs.mountMu.RUnlock() + from.mu.Unlock() + if to != nil { + to.setDisowned() + to.mu.Unlock() + if to.isMounted() { + vfs.forgetDisownedMountpoint(to) + } + } } // CommitRenameExchangeDentry must be called after the files represented by // from and to are exchanged by rename(RENAME_EXCHANGE). // +// CommitRenameExchangeDentry is a mutator of from, to, from.Parent(), and +// to.Parent(). +// // Preconditions: PrepareRenameDentry was previously called on from and to. func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) { from.parent, to.parent = to.parent, from.parent from.name, to.name = to.name, from.name from.parent.children[from.name] = from to.parent.children[to.name] = to - vfs.mountMu.RUnlock() + from.mu.Unlock() + to.mu.Unlock() +} + +// forgetDisownedMountpoint is called when a mount point is deleted to umount +// all mounts using it in all other mount namespaces. +// +// forgetDisownedMountpoint is analogous to Linux's +// fs/namespace.c:__detach_mounts(). +func (vfs *VirtualFilesystem) forgetDisownedMountpoint(d *Dentry) { + var ( + vdsToDecRef []VirtualDentry + mountsToDecRef []*Mount + ) + vfs.mountMu.Lock() + vfs.mounts.seq.BeginWrite() + for mnt := range vfs.mountpoints[d] { + vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(mnt, &umountRecursiveOptions{}, vdsToDecRef, mountsToDecRef) + } + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 86bde7fb3..c5a9adca3 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -47,15 +48,14 @@ type FileDescription struct { impl FileDescriptionImpl } -// Init must be called before first use of fd. It takes references on mnt and -// d. +// Init must be called before first use of fd. It takes ownership of references +// on mnt and d held by the caller. func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) { fd.refs = 1 fd.vd = VirtualDentry{ mount: mnt, dentry: d, } - fd.vd.IncRef() fd.impl = impl } @@ -64,6 +64,18 @@ func (fd *FileDescription) Impl() FileDescriptionImpl { return fd.impl } +// Mount returns the mount on which fd was opened. It does not take a reference +// on the returned Mount. +func (fd *FileDescription) Mount() *Mount { + return fd.vd.mount +} + +// Dentry returns the dentry at which fd was opened. It does not take a +// reference on the returned Dentry. +func (fd *FileDescription) Dentry() *Dentry { + return fd.vd.dentry +} + // VirtualDentry returns the location at which fd was opened. It does not take // a reference on the returned VirtualDentry. func (fd *FileDescription) VirtualDentry() VirtualDentry { @@ -75,6 +87,22 @@ func (fd *FileDescription) IncRef() { atomic.AddInt64(&fd.refs, 1) } +// TryIncRef increments fd's reference count and returns true. If fd's +// reference count is already zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fd. +func (fd *FileDescription) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fd.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) { + return true + } + } +} + // DecRef decrements fd's reference count. func (fd *FileDescription) DecRef() { if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 { @@ -102,7 +130,7 @@ type FileDescriptionImpl interface { // OnClose is called when a file descriptor representing the // FileDescription is closed. Note that returning a non-nil error does not // prevent the file descriptor from being closed. - OnClose() error + OnClose(ctx context.Context) error // StatusFlags returns file description status flags, as for // fcntl(F_GETFL). @@ -180,12 +208,26 @@ type FileDescriptionImpl interface { // ConfigureMMap mutates opts to implement mmap(2) for the file. Most // implementations that support memory mapping can call // GenericConfigureMMap with the appropriate memmap.Mappable. - ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error + ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // TODO: extended attributes; file locking + // Listxattr returns all extended attribute names for the file. + Listxattr(ctx context.Context) ([]string, error) + + // Getxattr returns the value associated with the given extended attribute + // for the file. + Getxattr(ctx context.Context, name string) (string, error) + + // Setxattr changes the value associated with the given extended attribute + // for the file. + Setxattr(ctx context.Context, opts SetxattrOptions) error + + // Removexattr removes the given extended attribute from the file. + Removexattr(ctx context.Context, name string) error + + // TODO: file locking } // Dirent holds the information contained in struct linux_dirent64. @@ -199,8 +241,11 @@ type Dirent struct { // Ino is the inode number. Ino uint64 - // Off is this Dirent's offset. - Off int64 + // NextOff is the offset of the *next* Dirent in the directory; that is, + // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will + // cause the next call to FileDescription.IterDirents() to yield the next + // Dirent. (The offset of the first Dirent in a directory is always 0.) + NextOff int64 } // IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents. @@ -211,3 +256,172 @@ type IterDirentsCallback interface { // called. Handle(dirent Dirent) bool } + +// OnClose is called when a file descriptor representing the FileDescription is +// closed. Returning a non-nil error should not prevent the file descriptor +// from being closed. +func (fd *FileDescription) OnClose(ctx context.Context) error { + return fd.impl.OnClose(ctx) +} + +// StatusFlags returns file description status flags, as for fcntl(F_GETFL). +func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { + flags, err := fd.impl.StatusFlags(ctx) + flags |= linux.O_LARGEFILE + return flags, err +} + +// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL). +func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { + return fd.impl.SetStatusFlags(ctx, flags) +} + +// Stat returns metadata for the file represented by fd. +func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + return fd.impl.Stat(ctx, opts) +} + +// SetStat updates metadata for the file represented by fd. +func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error { + return fd.impl.SetStat(ctx, opts) +} + +// StatFS returns metadata for the filesystem containing the file represented +// by fd. +func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + return fd.impl.StatFS(ctx) +} + +// PRead reads from the file represented by fd into dst, starting at the given +// offset, and returns the number of bytes read. PRead is permitted to return +// partial reads with a nil error. +func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return fd.impl.PRead(ctx, dst, offset, opts) +} + +// Read is similar to PRead, but does not specify an offset. +func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + return fd.impl.Read(ctx, dst, opts) +} + +// PWrite writes src to the file represented by fd, starting at the given +// offset, and returns the number of bytes written. PWrite is permitted to +// return partial writes with a nil error. +func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return fd.impl.PWrite(ctx, src, offset, opts) +} + +// Write is similar to PWrite, but does not specify an offset. +func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return fd.impl.Write(ctx, src, opts) +} + +// IterDirents invokes cb on each entry in the directory represented by fd. If +// IterDirents has been called since the last call to Seek, it continues +// iteration from the end of the last call. +func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error { + return fd.impl.IterDirents(ctx, cb) +} + +// Seek changes fd's offset (assuming one exists) and returns its new value. +func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.impl.Seek(ctx, offset, whence) +} + +// Sync has the semantics of fsync(2). +func (fd *FileDescription) Sync(ctx context.Context) error { + return fd.impl.Sync(ctx) +} + +// ConfigureMMap mutates opts to implement mmap(2) for the file represented by +// fd. +func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.impl.ConfigureMMap(ctx, opts) +} + +// Ioctl implements the ioctl(2) syscall. +func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.impl.Ioctl(ctx, uio, args) +} + +// Listxattr returns all extended attribute names for the file represented by +// fd. +func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { + names, err := fd.impl.Listxattr(ctx) + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by default + // don't exist. + return nil, nil + } + return names, err +} + +// Getxattr returns the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) { + return fd.impl.Getxattr(ctx, name) +} + +// Setxattr changes the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return fd.impl.Setxattr(ctx, opts) +} + +// Removexattr removes the given extended attribute from the file represented +// by fd. +func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { + return fd.impl.Removexattr(ctx, name) +} + +// SyncFS instructs the filesystem containing fd to execute the semantics of +// syncfs(2). +func (fd *FileDescription) SyncFS(ctx context.Context) error { + return fd.vd.mount.fs.impl.Sync(ctx) +} + +// MappedName implements memmap.MappingIdentity.MappedName. +func (fd *FileDescription) MappedName(ctx context.Context) string { + vfsroot := RootFromContext(ctx) + s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd) + if vfsroot.Ok() { + vfsroot.DecRef() + } + return s +} + +// DeviceID implements memmap.MappingIdentity.DeviceID. +func (fd *FileDescription) DeviceID() uint64 { + stat, err := fd.impl.Stat(context.Background(), StatOptions{ + // There is no STATX_DEV; we assume that Stat will return it if it's + // available regardless of mask. + Mask: 0, + // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_sb->s_dev + // directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return 0 + } + return uint64(linux.MakeDeviceID(uint16(stat.DevMajor), stat.DevMinor)) +} + +// InodeID implements memmap.MappingIdentity.InodeID. +func (fd *FileDescription) InodeID() uint64 { + stat, err := fd.impl.Stat(context.Background(), StatOptions{ + Mask: linux.STATX_INO, + // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil || stat.Mask&linux.STATX_INO == 0 { + return 0 + } + return stat.Ino +} + +// Msync implements memmap.MappingIdentity.Msync. +func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error { + return fd.impl.Sync(ctx) +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index ba230da72..3df49991c 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -45,7 +45,7 @@ type FileDescriptionDefaultImpl struct{} // OnClose implements FileDescriptionImpl.OnClose analogously to // file_operations::flush == NULL in Linux. -func (FileDescriptionDefaultImpl) OnClose() error { +func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error { return nil } @@ -117,7 +117,7 @@ func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error { // ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to // file_operations::mmap == NULL in Linux. -func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.ENODEV } @@ -127,6 +127,31 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } +// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// inode_operations::listxattr == NULL in Linux. +func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) { + // This isn't exactly accurate; see FileDescription.Listxattr. + return nil, syserror.ENOTSUP +} + +// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) { + return "", syserror.ENOTSUP +} + +// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return syserror.ENOTSUP +} + +// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { + return syserror.ENOTSUP +} + // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. @@ -252,3 +277,12 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6 fd.off = offset return offset, nil } + +// GenericConfigureMMap may be used by most implementations of +// FileDescriptionImpl.ConfigureMMap. +func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error { + opts.Mappable = m + opts.MappingIdentity = fd + fd.IncRef() + return nil +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 511b829fc..ac7799296 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -90,7 +90,7 @@ func TestGenCountFD(t *testing.T) { vfsObj := New() // vfs.New() vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{}) if err != nil { t.Fatalf("failed to create testfs root mount: %v", err) } @@ -103,7 +103,7 @@ func TestGenCountFD(t *testing.T) { // The first read causes Generate to be called to fill the FD's buffer. buf := make([]byte, 2) ioseq := usermem.BytesIOSequence(buf) - n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err := fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -112,17 +112,17 @@ func TestGenCountFD(t *testing.T) { } // A second read without seeking is still at EOF. - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 0 || err != io.EOF { t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) } // Seeking to the beginning of the file causes it to be regenerated. - n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET) + n, err = fd.Seek(ctx, 0, linux.SEEK_SET) if n != 0 || err != nil { t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) } - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -131,7 +131,7 @@ func TestGenCountFD(t *testing.T) { } // PRead at the beginning of the file also causes it to be regenerated. - n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{}) + n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 7a074b718..b766614e7 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -33,15 +34,28 @@ type Filesystem struct { // operations. refs int64 + // vfs is the VirtualFilesystem that uses this Filesystem. vfs is + // immutable. + vfs *VirtualFilesystem + // impl is the FilesystemImpl associated with this Filesystem. impl is // immutable. This should be the last field in Dentry. impl FilesystemImpl } // Init must be called before first use of fs. -func (fs *Filesystem) Init(impl FilesystemImpl) { +func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) { fs.refs = 1 + fs.vfs = vfsObj fs.impl = impl + vfsObj.filesystemsMu.Lock() + vfsObj.filesystems[fs] = struct{}{} + vfsObj.filesystemsMu.Unlock() +} + +// VirtualFilesystem returns the containing VirtualFilesystem. +func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem { + return fs.vfs } // Impl returns the FilesystemImpl associated with fs. @@ -49,14 +63,35 @@ func (fs *Filesystem) Impl() FilesystemImpl { return fs.impl } -func (fs *Filesystem) incRef() { +// IncRef increments fs' reference count. +func (fs *Filesystem) IncRef() { if atomic.AddInt64(&fs.refs, 1) <= 1 { - panic("Filesystem.incRef() called without holding a reference") + panic("Filesystem.IncRef() called without holding a reference") + } +} + +// TryIncRef increments fs' reference count and returns true. If fs' reference +// count is zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fs. +func (fs *Filesystem) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fs.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) { + return true + } } } -func (fs *Filesystem) decRef() { +// DecRef decrements fs' reference count. +func (fs *Filesystem) DecRef() { if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { + fs.vfs.filesystemsMu.Lock() + delete(fs.vfs.filesystems, fs) + fs.vfs.filesystemsMu.Unlock() fs.impl.Release() } else if refs < 0 { panic("Filesystem.decRef() called without holding a reference") @@ -151,5 +186,70 @@ type FilesystemImpl interface { // UnlinkAt removes the non-directory file at rp. UnlinkAt(ctx context.Context, rp *ResolvingPath) error - // TODO: d_path(); extended attributes; inotify_add_watch(); bind() + // ListxattrAt returns all extended attribute names for the file at rp. + ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) + + // GetxattrAt returns the value associated with the given extended + // attribute for the file at rp. + GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) + + // SetxattrAt changes the value associated with the given extended + // attribute for the file at rp. + SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + + // RemovexattrAt removes the given extended attribute from the file at rp. + RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + + // PrependPath prepends a path from vd to vd.Mount().Root() to b. + // + // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered + // before vd.Mount().Root(), PrependPath should stop prepending path + // components and return a PrependPathAtVFSRootError. + // + // If traversal of vd.Dentry()'s ancestors encounters an independent + // ("root") Dentry that is not vd.Mount().Root() (i.e. vd.Dentry() is not a + // descendant of vd.Mount().Root()), PrependPath should stop prepending + // path components and return a PrependPathAtNonMountRootError. + // + // Filesystems for which Dentries do not have meaningful paths may prepend + // an arbitrary descriptive string to b and then return a + // PrependPathSyntheticError. + // + // Most implementations can acquire the appropriate locks to ensure that + // Dentry.Name() and Dentry.Parent() are fixed for vd.Dentry() and all of + // its ancestors, then call GenericPrependPath. + // + // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. + PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error + + // TODO: inotify_add_watch(); bind() +} + +// PrependPathAtVFSRootError is returned by implementations of +// FilesystemImpl.PrependPath() when they encounter the contextual VFS root. +type PrependPathAtVFSRootError struct{} + +// Error implements error.Error. +func (PrependPathAtVFSRootError) Error() string { + return "vfs.FilesystemImpl.PrependPath() reached VFS root" +} + +// PrependPathAtNonMountRootError is returned by implementations of +// FilesystemImpl.PrependPath() when they encounter an independent ancestor +// Dentry that is not the Mount root. +type PrependPathAtNonMountRootError struct{} + +// Error implements error.Error. +func (PrependPathAtNonMountRootError) Error() string { + return "vfs.FilesystemImpl.PrependPath() reached root other than Mount root" +} + +// PrependPathSyntheticError is returned by implementations of +// FilesystemImpl.PrependPath() for which prepended names do not represent real +// paths. +type PrependPathSyntheticError struct{} + +// Error implements error.Error. +func (PrependPathSyntheticError) Error() string { + return "vfs.FilesystemImpl.PrependPath() prepended synthetic name" } diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go new file mode 100644 index 000000000..7315a588e --- /dev/null +++ b/pkg/sentry/vfs/filesystem_impl_util.go @@ -0,0 +1,69 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "strings" + + "gvisor.dev/gvisor/pkg/fspath" +) + +// GenericParseMountOptions parses a comma-separated list of options of the +// form "key" or "key=value", where neither key nor value contain commas, and +// returns it as a map. If str contains duplicate keys, then the last value +// wins. For example: +// +// str = "key0=value0,key1,key2=value2,key0=value3" -> map{'key0':'value3','key1':'','key2':'value2'} +// +// GenericParseMountOptions is not appropriate if values may contain commas, +// e.g. in the case of the mpol mount option for tmpfs(5). +func GenericParseMountOptions(str string) map[string]string { + m := make(map[string]string) + for _, opt := range strings.Split(str, ",") { + if len(opt) > 0 { + res := strings.SplitN(opt, "=", 2) + if len(res) == 2 { + m[res[0]] = res[1] + } else { + m[opt] = "" + } + } + } + return m +} + +// GenericPrependPath may be used by implementations of +// FilesystemImpl.PrependPath() for which a single statically-determined lock +// or set of locks is sufficient to ensure its preconditions (as opposed to +// e.g. per-Dentry locks). +// +// Preconditions: Dentry.Name() and Dentry.Parent() must be held constant for +// vd.Dentry() and all of its ancestors. +func GenericPrependPath(vfsroot, vd VirtualDentry, b *fspath.Builder) error { + mnt, d := vd.mount, vd.dentry + for { + if mnt == vfsroot.mount && d == vfsroot.dentry { + return PrependPathAtVFSRootError{} + } + if d == mnt.root { + return nil + } + if d.parent == nil { + return PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index f401ad7f3..c335e206d 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -25,21 +25,21 @@ import ( // // FilesystemType is analogous to Linux's struct file_system_type. type FilesystemType interface { - // NewFilesystem returns a Filesystem configured by the given options, + // GetFilesystem returns a Filesystem configured by the given options, // along with its mount root. A reference is taken on the returned // Filesystem and Dentry. - NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) + GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) } -// NewFilesystemOptions contains options to FilesystemType.NewFilesystem. -type NewFilesystemOptions struct { +// GetFilesystemOptions contains options to FilesystemType.GetFilesystem. +type GetFilesystemOptions struct { // Data is the string passed as the 5th argument to mount(2), which is // usually a comma-separated list of filesystem-specific mount options. Data string // InternalData holds opaque FilesystemType-specific data. There is // intentionally no way for applications to specify InternalData; if it is - // not nil, the call to NewFilesystem originates from within the sentry. + // not nil, the call to GetFilesystem originates from within the sentry. InternalData interface{} } diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 11702f720..ec23ab0dd 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -18,6 +18,7 @@ import ( "math" "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -38,16 +39,12 @@ import ( // Mount is analogous to Linux's struct mount. (gVisor does not distinguish // between struct mount and struct vfsmount.) type Mount struct { - // The lower 63 bits of refs are a reference count. The MSB of refs is set - // if the Mount has been eagerly unmounted, as by umount(2) without the - // MNT_DETACH flag. refs is accessed using atomic memory operations. - refs int64 - - // The lower 63 bits of writers is the number of calls to - // Mount.CheckBeginWrite() that have not yet been paired with a call to - // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. - // writers is accessed using atomic memory operations. - writers int64 + // vfs, fs, and root are immutable. References are held on fs and root. + // + // Invariant: root belongs to fs. + vfs *VirtualFilesystem + fs *Filesystem + root *Dentry // key is protected by VirtualFilesystem.mountMu and // VirtualFilesystem.mounts.seq, and may be nil. References are held on @@ -57,13 +54,29 @@ type Mount struct { // key.parent.fs. key mountKey - // fs, root, and ns are immutable. References are held on fs and root (but - // not ns). - // - // Invariant: root belongs to fs. - fs *Filesystem - root *Dentry - ns *MountNamespace + // ns is the namespace in which this Mount was mounted. ns is protected by + // VirtualFilesystem.mountMu. + ns *MountNamespace + + // The lower 63 bits of refs are a reference count. The MSB of refs is set + // if the Mount has been eagerly umounted, as by umount(2) without the + // MNT_DETACH flag. refs is accessed using atomic memory operations. + refs int64 + + // children is the set of all Mounts for which Mount.key.parent is this + // Mount. children is protected by VirtualFilesystem.mountMu. + children map[*Mount]struct{} + + // umounted is true if VFS.umountRecursiveLocked() has been called on this + // Mount. VirtualFilesystem does not hold a reference on Mounts for which + // umounted is true. umounted is protected by VirtualFilesystem.mountMu. + umounted bool + + // The lower 63 bits of writers is the number of calls to + // Mount.CheckBeginWrite() that have not yet been paired with a call to + // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. + // writers is accessed using atomic memory operations. + writers int64 } // A MountNamespace is a collection of Mounts. @@ -73,13 +86,16 @@ type Mount struct { // // MountNamespace is analogous to Linux's struct mnt_namespace. type MountNamespace struct { - refs int64 // accessed using atomic memory operations - // root is the MountNamespace's root mount. root is immutable. root *Mount - // mountpoints contains all Dentries which are mount points in this - // namespace. mountpoints is protected by VirtualFilesystem.mountMu. + // refs is the reference count. refs is accessed using atomic memory + // operations. + refs int64 + + // mountpoints maps all Dentries which are mount points in this namespace + // to the number of Mounts for which they are mount points. mountpoints is + // protected by VirtualFilesystem.mountMu. // // mountpoints is used to determine if a Dentry can be moved or removed // (which requires that the Dentry is not a mount point in the calling @@ -89,26 +105,27 @@ type MountNamespace struct { // MountNamespace; this is required to ensure that // VFS.PrepareDeleteDentry() and VFS.PrepareRemoveDentry() operate // correctly on unreferenced MountNamespaces. - mountpoints map[*Dentry]struct{} + mountpoints map[*Dentry]uint32 } // NewMountNamespace returns a new mount namespace with a root filesystem // configured by the given arguments. A reference is taken on the returned // MountNamespace. -func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *NewFilesystemOptions) (*MountNamespace, error) { +func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return nil, syserror.ENODEV } - fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts) if err != nil { return nil, err } mntns := &MountNamespace{ refs: 1, - mountpoints: make(map[*Dentry]struct{}), + mountpoints: make(map[*Dentry]uint32), } mntns.root = &Mount{ + vfs: vfs, fs: fs, root: root, ns: mntns, @@ -117,13 +134,13 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return mntns, nil } -// NewMount creates and mounts a new Filesystem. -func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *NewFilesystemOptions) error { +// MountAt creates and mounts a Filesystem configured by the given arguments. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return syserror.ENODEV } - fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return err } @@ -131,17 +148,19 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti // lock ordering. vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{}) if err != nil { - root.decRef(fs) - fs.decRef() + root.DecRef() + fs.DecRef() return err } vfs.mountMu.Lock() + vd.dentry.mu.Lock() for { if vd.dentry.IsDisowned() { + vd.dentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef() - root.decRef(fs) - fs.decRef() + root.DecRef() + fs.DecRef() return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -153,36 +172,272 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti if nextmnt == nil { break } - nextmnt.incRef() - nextmnt.root.incRef(nextmnt.fs) + // It's possible that nextmnt has been umounted but not disconnected, + // in which case vfs no longer holds a reference on it, and the last + // reference may be concurrently dropped even though we're holding + // vfs.mountMu. + if !nextmnt.tryIncMountedRef() { + break + } + // This can't fail since we're holding vfs.mountMu. + nextmnt.root.IncRef() + vd.dentry.mu.Unlock() vd.DecRef() vd = VirtualDentry{ mount: nextmnt, dentry: nextmnt.root, } + vd.dentry.mu.Lock() } // TODO: Linux requires that either both the mount point and the mount root // are directories, or neither are, and returns ENOTDIR if this is not the // case. mntns := vd.mount.ns mnt := &Mount{ + vfs: vfs, fs: fs, root: root, ns: mntns, refs: 1, } - mnt.storeKey(vd.mount, vd.dentry) + vfs.mounts.seq.BeginWrite() + vfs.connectLocked(mnt, vd, mntns) + vfs.mounts.seq.EndWrite() + vd.dentry.mu.Unlock() + vfs.mountMu.Unlock() + return nil +} + +// UmountAt removes the Mount at the given path. +func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { + if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { + return syserror.EINVAL + } + + // MNT_FORCE is currently unimplemented except for the permission check. + if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) { + return syserror.EPERM + } + + vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{}) + if err != nil { + return err + } + defer vd.DecRef() + if vd.dentry != vd.mount.root { + return syserror.EINVAL + } + vfs.mountMu.Lock() + if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns { + vfs.mountMu.Unlock() + return syserror.EINVAL + } + + // TODO(jamieliu): Linux special-cases umount of the caller's root, which + // we don't implement yet (we'll just fail it since the caller holds a + // reference on it). + + vfs.mounts.seq.BeginWrite() + if opts.Flags&linux.MNT_DETACH == 0 { + if len(vd.mount.children) != 0 { + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + // We are holding a reference on vd.mount. + expectedRefs := int64(1) + if !vd.mount.umounted { + expectedRefs = 2 + } + if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + } + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{ + eager: opts.Flags&linux.MNT_DETACH == 0, + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + return nil +} + +type umountRecursiveOptions struct { + // If eager is true, ensure that future calls to Mount.tryIncMountedRef() + // on umounted mounts fail. + // + // eager is analogous to Linux's UMOUNT_SYNC. + eager bool + + // If disconnectHierarchy is true, Mounts that are umounted hierarchically + // should be disconnected from their parents. (Mounts whose parents are not + // umounted, which in most cases means the Mount passed to the initial call + // to umountRecursiveLocked, are unconditionally disconnected for + // consistency with Linux.) + // + // disconnectHierarchy is analogous to Linux's !UMOUNT_CONNECTED. + disconnectHierarchy bool +} + +// umountRecursiveLocked marks mnt and its descendants as umounted. It does not +// release mount or dentry references; instead, it appends VirtualDentries and +// Mounts on which references must be dropped to vdsToDecRef and mountsToDecRef +// respectively, and returns updated slices. (This is necessary because +// filesystem locks possibly taken by DentryImpl.DecRef() may precede +// vfs.mountMu in the lock order, and Mount.DecRef() may lock vfs.mountMu.) +// +// umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree(). +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. +func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) { + if !mnt.umounted { + mnt.umounted = true + mountsToDecRef = append(mountsToDecRef, mnt) + if parent := mnt.parent(); parent != nil && (opts.disconnectHierarchy || !parent.umounted) { + vdsToDecRef = append(vdsToDecRef, vfs.disconnectLocked(mnt)) + } + } + if opts.eager { + for { + refs := atomic.LoadInt64(&mnt.refs) + if refs < 0 { + break + } + if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs|math.MinInt64) { + break + } + } + } + for child := range mnt.children { + vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(child, opts, vdsToDecRef, mountsToDecRef) + } + return vdsToDecRef, mountsToDecRef +} + +// connectLocked makes vd the mount parent/point for mnt. It consumes +// references held by vd. +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. d.mu must be locked. mnt.parent() == nil. +func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { + mnt.storeKey(vd) + if vd.mount.children == nil { + vd.mount.children = make(map[*Mount]struct{}) + } + vd.mount.children[mnt] = struct{}{} atomic.AddUint32(&vd.dentry.mounts, 1) - mntns.mountpoints[vd.dentry] = struct{}{} + mntns.mountpoints[vd.dentry]++ + vfs.mounts.insertSeqed(mnt) vfsmpmounts, ok := vfs.mountpoints[vd.dentry] if !ok { vfsmpmounts = make(map[*Mount]struct{}) vfs.mountpoints[vd.dentry] = vfsmpmounts } vfsmpmounts[mnt] = struct{}{} - vfs.mounts.Insert(mnt) - vfs.mountMu.Unlock() - return nil +} + +// disconnectLocked makes vd have no mount parent/point and returns its old +// mount parent/point with a reference held. +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. mnt.parent() != nil. +func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { + vd := mnt.loadKey() + mnt.storeKey(VirtualDentry{}) + delete(vd.mount.children, mnt) + atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 + mnt.ns.mountpoints[vd.dentry]-- + if mnt.ns.mountpoints[vd.dentry] == 0 { + delete(mnt.ns.mountpoints, vd.dentry) + } + vfs.mounts.removeSeqed(mnt) + vfsmpmounts := vfs.mountpoints[vd.dentry] + delete(vfsmpmounts, mnt) + if len(vfsmpmounts) == 0 { + delete(vfs.mountpoints, vd.dentry) + } + return vd +} + +// tryIncMountedRef increments mnt's reference count and returns true. If mnt's +// reference count is already zero, or has been eagerly umounted, +// tryIncMountedRef does nothing and returns false. +// +// tryIncMountedRef does not require that a reference is held on mnt. +func (mnt *Mount) tryIncMountedRef() bool { + for { + refs := atomic.LoadInt64(&mnt.refs) + if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted + return false + } + if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { + return true + } + } +} + +// IncRef increments mnt's reference count. +func (mnt *Mount) IncRef() { + // In general, negative values for mnt.refs are valid because the MSB is + // the eager-unmount bit. + atomic.AddInt64(&mnt.refs, 1) +} + +// DecRef decrements mnt's reference count. +func (mnt *Mount) DecRef() { + refs := atomic.AddInt64(&mnt.refs, -1) + if refs&^math.MinInt64 == 0 { // mask out MSB + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + mnt.root.DecRef() + mnt.fs.DecRef() + if vd.Ok() { + vd.DecRef() + } + } +} + +// IncRef increments mntns' reference count. +func (mntns *MountNamespace) IncRef() { + if atomic.AddInt64(&mntns.refs, 1) <= 1 { + panic("MountNamespace.IncRef() called without holding a reference") + } +} + +// DecRef decrements mntns' reference count. +func (mntns *MountNamespace) DecRef(vfs *VirtualFilesystem) { + if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 { + vfs.mountMu.Lock() + vfs.mounts.seq.BeginWrite() + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{ + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + } else if refs < 0 { + panic("MountNamespace.DecRef() called without holding a reference") + } } // getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes @@ -223,7 +478,7 @@ retryFirst: // Raced with umount. continue } - mnt.decRef() + mnt.DecRef() mnt = next d = next.root } @@ -231,12 +486,12 @@ retryFirst: } // getMountpointAt returns the mount point for the stack of Mounts including -// mnt. It takes a reference on the returned Mount and Dentry. If no such mount +// mnt. It takes a reference on the returned VirtualDentry. If no such mount // point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil). // // Preconditions: References are held on mnt and root. vfsroot is not (mnt, // mnt.root). -func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) (*Mount, *Dentry) { +func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry { // The first mount is special-cased: // // - The caller must have already checked mnt against vfsroot. @@ -246,21 +501,26 @@ func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) // - We don't drop the caller's reference on mnt. retryFirst: epoch := vfs.mounts.seq.BeginRead() - parent, point := mnt.loadKey() + parent, point := mnt.parent(), mnt.point() if !vfs.mounts.seq.ReadOk(epoch) { goto retryFirst } if parent == nil { - return nil, nil + return VirtualDentry{} } if !parent.tryIncMountedRef() { // Raced with umount. goto retryFirst } - if !point.tryIncRef(parent.fs) { + if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can only // happen due to a racing change to Mount.key. - parent.decRef() + parent.DecRef() + goto retryFirst + } + if !vfs.mounts.seq.ReadOk(epoch) { + point.DecRef() + parent.DecRef() goto retryFirst } mnt = parent @@ -274,7 +534,7 @@ retryFirst: } retryNotFirst: epoch := vfs.mounts.seq.BeginRead() - parent, point := mnt.loadKey() + parent, point := mnt.parent(), mnt.point() if !vfs.mounts.seq.ReadOk(epoch) { goto retryNotFirst } @@ -285,59 +545,23 @@ retryFirst: // Raced with umount. goto retryNotFirst } - if !point.tryIncRef(parent.fs) { + if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can // only happen due to a racing change to Mount.key. - parent.decRef() + parent.DecRef() goto retryNotFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.decRef(parent.fs) - parent.decRef() + point.DecRef() + parent.DecRef() goto retryNotFirst } - d.decRef(mnt.fs) - mnt.decRef() + d.DecRef() + mnt.DecRef() mnt = parent d = point } - return mnt, d -} - -// tryIncMountedRef increments mnt's reference count and returns true. If mnt's -// reference count is already zero, or has been eagerly unmounted, -// tryIncMountedRef does nothing and returns false. -// -// tryIncMountedRef does not require that a reference is held on mnt. -func (mnt *Mount) tryIncMountedRef() bool { - for { - refs := atomic.LoadInt64(&mnt.refs) - if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted - return false - } - if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { - return true - } - } -} - -func (mnt *Mount) incRef() { - // In general, negative values for mnt.refs are valid because the MSB is - // the eager-unmount bit. - atomic.AddInt64(&mnt.refs, 1) -} - -func (mnt *Mount) decRef() { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - parent, point := mnt.loadKey() - if point != nil { - point.decRef(parent.fs) - parent.decRef() - } - mnt.root.decRef(mnt.fs) - mnt.fs.decRef() - } + return VirtualDentry{mnt, d} } // CheckBeginWrite increments the counter of in-progress write operations on @@ -360,7 +584,7 @@ func (mnt *Mount) EndWrite() { atomic.AddInt64(&mnt.writers, -1) } -// Preconditions: VirtualFilesystem.mountMu must be locked for writing. +// Preconditions: VirtualFilesystem.mountMu must be locked. func (mnt *Mount) setReadOnlyLocked(ro bool) error { if oldRO := atomic.LoadInt64(&mnt.writers) < 0; oldRO == ro { return nil @@ -383,22 +607,6 @@ func (mnt *Mount) Filesystem() *Filesystem { return mnt.fs } -// IncRef increments mntns' reference count. -func (mntns *MountNamespace) IncRef() { - if atomic.AddInt64(&mntns.refs, 1) <= 1 { - panic("MountNamespace.IncRef() called without holding a reference") - } -} - -// DecRef decrements mntns' reference count. -func (mntns *MountNamespace) DecRef() { - if refs := atomic.AddInt64(&mntns.refs, 0); refs == 0 { - // TODO: unmount mntns.root - } else if refs < 0 { - panic("MountNamespace.DecRef() called without holding a reference") - } -} - // Root returns mntns' root. A reference is taken on the returned // VirtualDentry. func (mntns *MountNamespace) Root() VirtualDentry { diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index f394d7483..adff0b94b 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -37,7 +37,7 @@ func TestMountTableInsertLookup(t *testing.T) { mt.Init() mount := &Mount{} - mount.storeKey(&Mount{}, &Dentry{}) + mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) mt.Insert(mount) if m := mt.Lookup(mount.parent(), mount.point()); m != mount { @@ -78,18 +78,10 @@ const enableComparativeBenchmarks = false func newBenchMount() *Mount { mount := &Mount{} - mount.storeKey(&Mount{}, &Dentry{}) + mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) return mount } -func vdkey(mnt *Mount) VirtualDentry { - parent, point := mnt.loadKey() - return VirtualDentry{ - mount: parent, - dentry: point, - } -} - func BenchmarkMountTableParallelLookup(b *testing.B) { for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { for _, numMounts := range benchNumMounts { @@ -101,7 +93,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { for i := 0; i < numMounts; i++ { mount := newBenchMount() mt.Insert(mount) - keys = append(keys, vdkey(mount)) + keys = append(keys, mount.loadKey()) } var ready sync.WaitGroup @@ -153,7 +145,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := vdkey(mount) + key := mount.loadKey() ms[key] = mount keys = append(keys, key) } @@ -208,7 +200,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := vdkey(mount) + key := mount.loadKey() ms.Store(key, mount) keys = append(keys, key) } @@ -290,7 +282,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -325,7 +317,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { var ms sync.Map for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -379,7 +371,7 @@ func BenchmarkMountMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } } @@ -399,7 +391,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } } @@ -432,13 +424,13 @@ func BenchmarkMountMapRemove(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := range mounts { mount := mounts[i] - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } b.ResetTimer() for i := range mounts { mount := mounts[i] - delete(ms, vdkey(mount)) + delete(ms, mount.loadKey()) } } @@ -454,12 +446,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) { var ms sync.Map for i := range mounts { mount := mounts[i] - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Delete(vdkey(mount)) + ms.Delete(mount.loadKey()) } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b0511aa40..ab13fa461 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // mountKey represents the location at which a Mount is mounted. It is @@ -38,16 +38,6 @@ type mountKey struct { point unsafe.Pointer // *Dentry } -// Invariant: mnt.key's fields are nil. parent and point are non-nil. -func (mnt *Mount) storeKey(parent *Mount, point *Dentry) { - atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(parent)) - atomic.StorePointer(&mnt.key.point, unsafe.Pointer(point)) -} - -func (mnt *Mount) loadKey() (*Mount, *Dentry) { - return (*Mount)(atomic.LoadPointer(&mnt.key.parent)), (*Dentry)(atomic.LoadPointer(&mnt.key.point)) -} - func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,6 +46,19 @@ func (mnt *Mount) point() *Dentry { return (*Dentry)(atomic.LoadPointer(&mnt.key.point)) } +func (mnt *Mount) loadKey() VirtualDentry { + return VirtualDentry{ + mount: mnt.parent(), + dentry: mnt.point(), + } +} + +// Invariant: mnt.key.parent == nil. vd.Ok(). +func (mnt *Mount) storeKey(vd VirtualDentry) { + atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) + atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) +} + // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). @@ -72,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq gvsync.SeqCount + seq syncutil.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of @@ -201,9 +204,19 @@ loop: // Insert inserts the given mount into mt. // -// Preconditions: There are no concurrent mutators of mt. mt must not already -// contain a Mount with the same mount point and parent. +// Preconditions: mt must not already contain a Mount with the same mount point +// and parent. func (mt *mountTable) Insert(mount *Mount) { + mt.seq.BeginWrite() + mt.insertSeqed(mount) + mt.seq.EndWrite() +} + +// insertSeqed inserts the given mount into mt. +// +// Preconditions: mt.seq must be in a writer critical section. mt must not +// already contain a Mount with the same mount point and parent. +func (mt *mountTable) insertSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) // We're under the maximum load factor if: @@ -215,10 +228,8 @@ func (mt *mountTable) Insert(mount *Mount) { tcap := uintptr(1) << order if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) { // Atomically insert the new element into the table. - mt.seq.BeginWrite() atomic.AddUint64(&mt.size, mtSizeLenOne) mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash) - mt.seq.EndWrite() return } @@ -241,8 +252,6 @@ func (mt *mountTable) Insert(mount *Mount) { for { oldSlot := (*mountSlot)(oldCur) if oldSlot.value != nil { - // Don't need to lock mt.seq yet since newSlots isn't visible - // to readers. mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash) } if oldCur == oldLast { @@ -252,11 +261,9 @@ func (mt *mountTable) Insert(mount *Mount) { } // Insert the new element into the new table. mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash) - // Atomically switch to the new table. - mt.seq.BeginWrite() + // Switch to the new table. atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne) atomic.StorePointer(&mt.slots, newSlots) - mt.seq.EndWrite() } // Preconditions: There are no concurrent mutators of the table (slots, cap). @@ -294,9 +301,18 @@ func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, has // Remove removes the given mount from mt. // -// Preconditions: There are no concurrent mutators of mt. mt must contain -// mount. +// Preconditions: mt must contain mount. func (mt *mountTable) Remove(mount *Mount) { + mt.seq.BeginWrite() + mt.removeSeqed(mount) + mt.seq.EndWrite() +} + +// removeSeqed removes the given mount from mt. +// +// Preconditions: mt.seq must be in a writer critical section. mt must contain +// mount. +func (mt *mountTable) removeSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 @@ -311,7 +327,6 @@ func (mt *mountTable) Remove(mount *Mount) { // backward until we either find an empty slot, or an element that // is already in its first-probed slot. (This is backward shift // deletion.) - mt.seq.BeginWrite() for { nextOff := (off + mountSlotBytes) & offmask nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff)) @@ -330,7 +345,6 @@ func (mt *mountTable) Remove(mount *Mount) { } atomic.StorePointer(&slot.value, nil) atomic.AddUint64(&mt.size, mtSizeLenNegOne) - mt.seq.EndWrite() return } if checkInvariants && slotValue == nil { diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 187e5410c..97ee4a446 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -31,14 +31,14 @@ type GetDentryOptions struct { // FilesystemImpl.MkdirAt(). type MkdirOptions struct { // Mode is the file mode bits for the created directory. - Mode uint16 + Mode linux.FileMode } // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). type MknodOptions struct { // Mode is the file type and mode bits for the created file. - Mode uint16 + Mode linux.FileMode // If Mode specifies a character or block device special file, DevMajor and // DevMinor are the major and minor device numbers for the created device. @@ -46,6 +46,12 @@ type MknodOptions struct { DevMinor uint32 } +// MountOptions contains options to VirtualFilesystem.MountAt(). +type MountOptions struct { + // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). + GetFilesystemOptions GetFilesystemOptions +} + // OpenOptions contains options to VirtualFilesystem.OpenAt() and // FilesystemImpl.OpenAt(). type OpenOptions struct { @@ -61,7 +67,7 @@ type OpenOptions struct { // If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the // created file. - Mode uint16 + Mode linux.FileMode } // ReadOptions contains options to FileDescription.PRead(), @@ -95,6 +101,20 @@ type SetStatOptions struct { Stat linux.Statx } +// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), +// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and +// FileDescriptionImpl.Setxattr(). +type SetxattrOptions struct { + // Name is the name of the extended attribute being mutated. + Name string + + // Value is the extended attribute's new value. + Value string + + // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2). + Flags uint32 +} + // StatOptions contains options to VirtualFilesystem.StatAt(), // FilesystemImpl.StatAt(), FileDescription.Stat(), and // FileDescriptionImpl.Stat(). @@ -114,6 +134,12 @@ type StatOptions struct { Sync uint32 } +// UmountOptions contains options to VirtualFilesystem.UmountAt(). +type UmountOptions struct { + // Flags contains flags as specified for umount2(2). + Flags uint32 +} + // WriteOptions contains options to FileDescription.PWrite(), // FileDescriptionImpl.PWrite(), FileDescription.Write(), and // FileDescriptionImpl.Write(). diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go new file mode 100644 index 000000000..8e155654f --- /dev/null +++ b/pkg/sentry/vfs/pathname.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/syserror" +) + +var fspathBuilderPool = sync.Pool{ + New: func() interface{} { + return &fspath.Builder{} + }, +} + +func getFSPathBuilder() *fspath.Builder { + return fspathBuilderPool.Get().(*fspath.Builder) +} + +func putFSPathBuilder(b *fspath.Builder) { + // No methods can be called on b after b.String(), so reset it to its zero + // value (as returned by fspathBuilderPool.New) instead. + *b = fspath.Builder{} + fspathBuilderPool.Put(b) +} + +// PathnameWithDeleted returns an absolute pathname to vd, consistent with +// Linux's d_path(). In particular, if vd.Dentry() has been disowned, +// PathnameWithDeleted appends " (deleted)" to the returned pathname. +func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { + b := getFSPathBuilder() + defer putFSPathBuilder(b) + haveRef := false + defer func() { + if haveRef { + vd.DecRef() + } + }() + + origD := vd.dentry +loop: + for { + err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b) + switch err.(type) { + case nil: + if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { + // GenericPrependPath() will have returned + // PrependPathAtVFSRootError in this case since it checks + // against vfsroot before mnt.root, but other implementations + // of FilesystemImpl.PrependPath() may return nil instead. + break loop + } + nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + if !nextVD.Ok() { + break loop + } + if haveRef { + vd.DecRef() + } + vd = nextVD + haveRef = true + // continue loop + case PrependPathSyntheticError: + // Skip prepending "/" and appending " (deleted)". + return b.String(), nil + case PrependPathAtVFSRootError, PrependPathAtNonMountRootError: + break loop + default: + return "", err + } + } + b.PrependByte('/') + if origD.IsDisowned() { + b.AppendString(" (deleted)") + } + return b.String(), nil +} + +// PathnameForGetcwd returns an absolute pathname to vd, consistent with +// Linux's sys_getcwd(). +func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { + if vd.dentry.IsDisowned() { + return "", syserror.ENOENT + } + + b := getFSPathBuilder() + defer putFSPathBuilder(b) + haveRef := false + defer func() { + if haveRef { + vd.DecRef() + } + }() + unreachable := false +loop: + for { + err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b) + switch err.(type) { + case nil: + if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { + break loop + } + nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + if !nextVD.Ok() { + unreachable = true + break loop + } + if haveRef { + vd.DecRef() + } + vd = nextVD + haveRef = true + case PrependPathAtVFSRootError: + break loop + case PrependPathAtNonMountRootError, PrependPathSyntheticError: + unreachable = true + break loop + default: + return "", err + } + } + b.PrependByte('/') + if unreachable { + b.PrependString("(unreachable)") + } + return b.String(), nil +} + +// As of this writing, we do not have equivalents to: +// +// - d_absolute_path(), which returns EINVAL if (effectively) any call to +// FilesystemImpl.PrependPath() would return PrependPathAtNonMountRootError. +// +// - dentry_path(), which does not walk up mounts (and only returns the path +// relative to Filesystem root), but also appends "//deleted" for disowned +// Dentries. +// +// These should be added as necessary. diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index f8e74355c..f1edb0680 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -119,3 +119,65 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { return false } } + +// CheckSetStat checks that creds has permission to change the metadata of a +// file with the given permissions, UID, and GID as specified by stat, subject +// to the rules of Linux's fs/attr.c:setattr_prepare(). +func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { + if stat.Mask&linux.STATX_MODE != 0 { + if !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + // TODO(b/30815691): "If the calling process is not privileged (Linux: + // does not have the CAP_FSETID capability), and the group of the file + // does not match the effective group ID of the process or one of its + // supplementary group IDs, the S_ISGID bit will be turned off, but + // this will not cause an error to be returned." - chmod(2) + } + if stat.Mask&linux.STATX_UID != 0 { + if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&linux.STATX_GID != 0 { + if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { + if !CanActAsOwner(creds, kuid) { + if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) { + return syserror.EPERM + } + // isDir is irrelevant in the following call to + // GenericCheckPermissions since ats == MayWrite means that + // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE + // applies, regardless of isDir. + if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil { + return err + } + } + } + return nil +} + +// CanActAsOwner returns true if creds can act as the owner of a file with the +// given owning UID, consistent with Linux's +// fs/inode.c:inode_owner_or_capable(). +func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool { + if creds.EffectiveKUID == kuid { + return true + } + return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok() +} + +// HasCapabilityOnFile returns true if creds has the given capability with +// respect to a file with the given owning UID and GID, consistent with Linux's +// kernel/capability.c:capable_wrt_inode_uidgid(). +func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool { + return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok() +} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 8d05c8583..621f5a6f8 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -149,20 +149,20 @@ func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) { func (rp *ResolvingPath) decRefStartAndMount() { if rp.flags&rpflagsHaveStartRef != 0 { - rp.start.decRef(rp.mount.fs) + rp.start.DecRef() } if rp.flags&rpflagsHaveMountRef != 0 { - rp.mount.decRef() + rp.mount.DecRef() } } func (rp *ResolvingPath) releaseErrorState() { if rp.nextStart != nil { - rp.nextStart.decRef(rp.nextMount.fs) + rp.nextStart.DecRef() rp.nextStart = nil } if rp.nextMount != nil { - rp.nextMount.decRef() + rp.nextMount.DecRef() rp.nextMount = nil } } @@ -269,11 +269,11 @@ func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) { parent = d } else if d == rp.mount.root { // At mount root ... - mnt, mntpt := rp.vfs.getMountpointAt(rp.mount, rp.root) - if mnt != nil { + vd := rp.vfs.getMountpointAt(rp.mount, rp.root) + if vd.Ok() { // ... of non-root mount. - rp.nextMount = mnt - rp.nextStart = mntpt + rp.nextMount = vd.mount + rp.nextStart = vd.dentry return nil, resolveMountRootError{} } // ... of root mount. diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go deleted file mode 100644 index 23f2b9e08..000000000 --- a/pkg/sentry/vfs/syscalls.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" -) - -// PathOperation specifies the path operated on by a VFS method. -// -// PathOperation is passed to VFS methods by pointer to reduce memory copying: -// it's somewhat large and should never escape. (Options structs are passed by -// pointer to VFS and FileDescription methods for the same reason.) -type PathOperation struct { - // Root is the VFS root. References on Root are borrowed from the provider - // of the PathOperation. - // - // Invariants: Root.Ok(). - Root VirtualDentry - - // Start is the starting point for the path traversal. References on Start - // are borrowed from the provider of the PathOperation (i.e. the caller of - // the VFS method to which the PathOperation was passed). - // - // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. - Start VirtualDentry - - // Path is the pathname traversed by this operation. - Pathname string - - // If FollowFinalSymlink is true, and the Dentry traversed by the final - // path component represents a symbolic link, the symbolic link should be - // followed. - FollowFinalSymlink bool -} - -// GetDentryAt returns a VirtualDentry representing the given path, at which a -// file must exist. A reference is taken on the returned VirtualDentry. -func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return VirtualDentry{}, err - } - for { - d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) - if err == nil { - vd := VirtualDentry{ - mount: rp.mount, - dentry: d, - } - rp.mount.incRef() - vfs.putResolvingPath(rp) - return vd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return VirtualDentry{}, err - } - } -} - -// MkdirAt creates a directory at the given path. -func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { - // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is - // also honored." - mkdir(2) - opts.Mode &= 01777 - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err - } - for { - err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return err - } - } -} - -// OpenAt returns a FileDescription providing access to the file at the given -// path. A reference is taken on the returned FileDescription. -func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { - // Remove: - // - // - O_LARGEFILE, which we always report in FileDescription status flags - // since only 64-bit architectures are supported at this time. - // - // - O_CLOEXEC, which affects file descriptors and therefore must be - // handled outside of VFS. - // - // - Unknown flags. - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE - // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. - if opts.Flags&linux.O_SYNC != 0 { - opts.Flags |= linux.O_DSYNC - } - // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified - // with O_DIRECTORY and a writable access mode (to ensure that it fails on - // filesystem implementations that do not support it). - if opts.Flags&linux.O_TMPFILE != 0 { - if opts.Flags&linux.O_DIRECTORY == 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { - return nil, syserror.EINVAL - } - } - // O_PATH causes most other flags to be ignored. - if opts.Flags&linux.O_PATH != 0 { - opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH - } - // "On Linux, the following bits are also honored in mode: [S_ISUID, - // S_ISGID, S_ISVTX]" - open(2) - opts.Mode &= 07777 - - if opts.Flags&linux.O_NOFOLLOW != 0 { - pop.FollowFinalSymlink = false - } - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil, err - } - if opts.Flags&linux.O_DIRECTORY != 0 { - rp.mustBeDir = true - rp.mustBeDirOrig = true - } - for { - fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return fd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return nil, err - } - } -} - -// StatAt returns metadata for the file at the given path. -func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return linux.Statx{}, err - } - for { - stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return stat, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return linux.Statx{}, err - } - } -} - -// StatusFlags returns file description status flags. -func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { - flags, err := fd.impl.StatusFlags(ctx) - flags |= linux.O_LARGEFILE - return flags, err -} - -// SetStatusFlags sets file description status flags. -func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - return fd.impl.SetStatusFlags(ctx, flags) -} - -// TODO: -// -// - VFS.SyncAllFilesystems() for sync(2) -// -// - Something for syncfs(2) -// -// - VFS.LinkAt() -// -// - VFS.MknodAt() -// -// - VFS.ReadlinkAt() -// -// - VFS.RenameAt() -// -// - VFS.RmdirAt() -// -// - VFS.SetStatAt() -// -// - VFS.StatFSAt() -// -// - VFS.SymlinkAt() -// -// - VFS.UnlinkAt() -// -// - FileDescription.(almost everything) diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go index 70b192ece..d94117bce 100644 --- a/pkg/sentry/vfs/testutil.go +++ b/pkg/sentry/vfs/testutil.go @@ -15,7 +15,10 @@ package vfs import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -33,10 +36,10 @@ type FDTestFilesystem struct { vfsfs Filesystem } -// NewFilesystem implements FilesystemType.NewFilesystem. -func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) { +// GetFilesystem implements FilesystemType.GetFilesystem. +func (fstype FDTestFilesystemType) GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) { var fs FDTestFilesystem - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) return &fs.vfsfs, fs.NewDentry(), nil } @@ -114,6 +117,32 @@ func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) err return syserror.EPERM } +// ListxattrAt implements FilesystemImpl.ListxattrAt. +func (fs *FDTestFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) { + return nil, syserror.EPERM +} + +// GetxattrAt implements FilesystemImpl.GetxattrAt. +func (fs *FDTestFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) { + return "", syserror.EPERM +} + +// SetxattrAt implements FilesystemImpl.SetxattrAt. +func (fs *FDTestFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { + return syserror.EPERM +} + +// RemovexattrAt implements FilesystemImpl.RemovexattrAt. +func (fs *FDTestFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { + return syserror.EPERM +} + +// PrependPath implements FilesystemImpl.PrependPath. +func (fs *FDTestFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error { + b.PrependComponent(fmt.Sprintf("vfs.fdTestDentry:%p", vd.dentry.impl.(*fdTestDentry))) + return PrependPathSyntheticError{} +} + type fdTestDentry struct { vfsd Dentry } @@ -126,14 +155,14 @@ func (fs *FDTestFilesystem) NewDentry() *Dentry { } // IncRef implements DentryImpl.IncRef. -func (d *fdTestDentry) IncRef(vfsfs *Filesystem) { +func (d *fdTestDentry) IncRef() { } // TryIncRef implements DentryImpl.TryIncRef. -func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool { +func (d *fdTestDentry) TryIncRef() bool { return true } // DecRef implements DentryImpl.DecRef. -func (d *fdTestDentry) DecRef(vfsfs *Filesystem) { +func (d *fdTestDentry) DecRef() { } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 4a8a69540..e60898d7c 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -16,13 +16,24 @@ // // Lock order: // -// Filesystem implementation locks +// FilesystemImpl/FileDescriptionImpl locks // VirtualFilesystem.mountMu +// Dentry.mu +// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry +// VirtualFilesystem.filesystemsMu // VirtualFilesystem.fsTypesMu +// +// Locking Dentry.mu in multiple Dentries requires holding +// VirtualFilesystem.mountMu. package vfs import ( "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" ) // A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts. @@ -33,7 +44,7 @@ type VirtualFilesystem struct { // mountMu serializes mount mutations. // // mountMu is analogous to Linux's namespace_sem. - mountMu sync.RWMutex + mountMu sync.Mutex // mounts maps (mount parent, mount point) pairs to mounts. (Since mounts // are uniquely namespaced, including mount parent in the key correctly @@ -52,7 +63,7 @@ type VirtualFilesystem struct { // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. // - // mountpoints is used to find mounts that must be unmounted due to + // mountpoints is used to find mounts that must be umounted due to // removal of a mount point Dentry from another mount namespace. ("A file // or directory that is a mount point in one namespace that is not a mount // point in another namespace, may be renamed, unlinked, or removed @@ -62,6 +73,11 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // filesystems contains all Filesystems. filesystems is protected by + // filesystemsMu. + filesystemsMu sync.Mutex + filesystems map[*Filesystem]struct{} + // fsTypes contains all FilesystemTypes that are usable in the // VirtualFilesystem. fsTypes is protected by fsTypesMu. fsTypesMu sync.RWMutex @@ -72,12 +88,466 @@ type VirtualFilesystem struct { func New() *VirtualFilesystem { vfs := &VirtualFilesystem{ mountpoints: make(map[*Dentry]map[*Mount]struct{}), + filesystems: make(map[*Filesystem]struct{}), fsTypes: make(map[string]FilesystemType), } vfs.mounts.Init() return vfs } +// PathOperation specifies the path operated on by a VFS method. +// +// PathOperation is passed to VFS methods by pointer to reduce memory copying: +// it's somewhat large and should never escape. (Options structs are passed by +// pointer to VFS and FileDescription methods for the same reason.) +type PathOperation struct { + // Root is the VFS root. References on Root are borrowed from the provider + // of the PathOperation. + // + // Invariants: Root.Ok(). + Root VirtualDentry + + // Start is the starting point for the path traversal. References on Start + // are borrowed from the provider of the PathOperation (i.e. the caller of + // the VFS method to which the PathOperation was passed). + // + // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. + Start VirtualDentry + + // Path is the pathname traversed by this operation. + Pathname string + + // If FollowFinalSymlink is true, and the Dentry traversed by the final + // path component represents a symbolic link, the symbolic link should be + // followed. + FollowFinalSymlink bool +} + +// GetDentryAt returns a VirtualDentry representing the given path, at which a +// file must exist. A reference is taken on the returned VirtualDentry. +func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return VirtualDentry{}, err + } + for { + d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) + if err == nil { + vd := VirtualDentry{ + mount: rp.mount, + dentry: d, + } + rp.mount.IncRef() + vfs.putResolvingPath(rp) + return vd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return VirtualDentry{}, err + } + } +} + +// LinkAt creates a hard link at newpop representing the existing file at +// oldpop. +func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// MkdirAt creates a directory at the given path. +func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { + // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is + // also honored." - mkdir(2) + opts.Mode &= 0777 | linux.S_ISVTX + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// MknodAt creates a file of the given mode at the given path. It returns an +// error from the syserror package. +func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil + } + for { + if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { + vfs.putResolvingPath(rp) + return nil + } + // Handle mount traversals. + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// OpenAt returns a FileDescription providing access to the file at the given +// path. A reference is taken on the returned FileDescription. +func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { + // Remove: + // + // - O_LARGEFILE, which we always report in FileDescription status flags + // since only 64-bit architectures are supported at this time. + // + // - O_CLOEXEC, which affects file descriptors and therefore must be + // handled outside of VFS. + // + // - Unknown flags. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE + // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. + if opts.Flags&linux.O_SYNC != 0 { + opts.Flags |= linux.O_DSYNC + } + // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified + // with O_DIRECTORY and a writable access mode (to ensure that it fails on + // filesystem implementations that do not support it). + if opts.Flags&linux.O_TMPFILE != 0 { + if opts.Flags&linux.O_DIRECTORY == 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_CREAT != 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { + return nil, syserror.EINVAL + } + } + // O_PATH causes most other flags to be ignored. + if opts.Flags&linux.O_PATH != 0 { + opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH + } + // "On Linux, the following bits are also honored in mode: [S_ISUID, + // S_ISGID, S_ISVTX]" - open(2) + opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX + + if opts.Flags&linux.O_NOFOLLOW != 0 { + pop.FollowFinalSymlink = false + } + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + if opts.Flags&linux.O_DIRECTORY != 0 { + rp.mustBeDir = true + rp.mustBeDirOrig = true + } + for { + fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return fd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// ReadlinkAt returns the target of the symbolic link at the given path. +func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return target, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// RenameAt renames the file at oldpop to newpop. +func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// RmdirAt removes the directory at the given path. +func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RmdirAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SetStatAt changes metadata for the file at the given path. +func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// StatAt returns metadata for the file at the given path. +func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statx{}, err + } + for { + stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return stat, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statx{}, err + } + } +} + +// StatFSAt returns metadata for the filesystem containing the file at the +// given path. +func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statfs{}, err + } + for { + statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return statfs, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statfs{}, err + } + } +} + +// SymlinkAt creates a symbolic link at the given path with the given target. +func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// UnlinkAt deletes the non-directory file at the given path. +func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.UnlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// ListxattrAt returns all extended attribute names for the file at the given +// path. +func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + for { + names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return names, nil + } + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by + // default don't exist. + vfs.putResolvingPath(rp) + return nil, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// GetxattrAt returns the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return val, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// SetxattrAt changes the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// RemovexattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SyncAllFilesystems has the semantics of Linux's sync(2). +func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + fss := make(map[*Filesystem]struct{}) + vfs.filesystemsMu.Lock() + for fs := range vfs.filesystems { + if !fs.TryIncRef() { + continue + } + fss[fs] = struct{}{} + } + vfs.filesystemsMu.Unlock() + var retErr error + for fs := range fss { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef() + } + return retErr +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). @@ -111,15 +581,15 @@ func (vd VirtualDentry) Ok() bool { // IncRef increments the reference counts on the Mount and Dentry represented // by vd. func (vd VirtualDentry) IncRef() { - vd.mount.incRef() - vd.dentry.incRef(vd.mount.fs) + vd.mount.IncRef() + vd.dentry.IncRef() } // DecRef decrements the reference counts on the Mount and Dentry represented // by vd. func (vd VirtualDentry) DecRef() { - vd.dentry.decRef(vd.mount.fs) - vd.mount.decRef() + vd.dentry.DecRef() + vd.mount.DecRef() } // Mount returns the Mount associated with vd. It does not take a reference on diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 145102c0d..5e4611333 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -42,8 +42,35 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) -// DefaultTimeout is a resonable timeout value for most applications. -const DefaultTimeout = 3 * time.Minute +// Opts configures the watchdog. +type Opts struct { + // TaskTimeout is the amount of time to allow a task to execute the + // same syscall without blocking before it's declared stuck. + TaskTimeout time.Duration + + // TaskTimeoutAction indicates what action to take when a stuck tasks + // is detected. + TaskTimeoutAction Action + + // StartupTimeout is the amount of time to allow between watchdog + // creation and calling watchdog.Start. + StartupTimeout time.Duration + + // StartupTimeoutAction indicates what action to take when + // watchdog.Start is not called within the timeout. + StartupTimeoutAction Action +} + +// DefaultOpts is a default set of options for the watchdog. +var DefaultOpts = Opts{ + // Task timeout. + TaskTimeout: 3 * time.Minute, + TaskTimeoutAction: LogWarning, + + // Startup timeout. + StartupTimeout: 30 * time.Second, + StartupTimeoutAction: LogWarning, +} // descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period // is discounted from task's last update time. It's set high enough that small scheduling delays won't @@ -61,6 +88,7 @@ type Action int const ( // LogWarning logs warning message followed by stack trace. LogWarning Action = iota + // Panic will do the same logging as LogWarning and panic(). Panic ) @@ -80,17 +108,13 @@ func (a Action) String() string { // Watchdog is the main watchdog class. It controls a goroutine that periodically // analyses all tasks and reports if any of them appear to be stuck. type Watchdog struct { + // Configuration options are embedded. + Opts + // period indicates how often to check all tasks. It's calculated based on - // 'taskTimeout'. + // opts.TaskTimeout. period time.Duration - // taskTimeout is the amount of time to allow a task to execute the same syscall - // without blocking before it's declared stuck. - taskTimeout time.Duration - - // timeoutAction indicates what action to take when a stuck tasks is detected. - timeoutAction Action - // k is where the tasks come from. k *kernel.Kernel @@ -113,8 +137,12 @@ type Watchdog struct { // mu protects the fields below. mu sync.Mutex - // started is true if the watchdog has been started before. - started bool + // running is true if the watchdog is running. + running bool + + // startCalled is true if Start has ever been called. It remains true + // even if Stop is called. + startCalled bool } type offender struct { @@ -122,58 +150,81 @@ type offender struct { } // New creates a new watchdog. -func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog { - // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much. - period := taskTimeout / 4 - return &Watchdog{ - k: k, - period: period, - taskTimeout: taskTimeout, - timeoutAction: a, - offenders: make(map[*kernel.Task]*offender), - stop: make(chan struct{}), - done: make(chan struct{}), +func New(k *kernel.Kernel, opts Opts) *Watchdog { + // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much. + period := opts.TaskTimeout / 4 + w := &Watchdog{ + Opts: opts, + k: k, + period: period, + offenders: make(map[*kernel.Task]*offender), + stop: make(chan struct{}), + done: make(chan struct{}), + } + + // Handle StartupTimeout if it exists. + if w.StartupTimeout > 0 { + log.Infof("Watchdog waiting %v for startup", w.StartupTimeout) + go w.waitForStart() // S/R-SAFE: watchdog is stopped buring save and restarted after restore. } + + return w } // Start starts the watchdog. func (w *Watchdog) Start() { - if w.taskTimeout == 0 { - log.Infof("Watchdog disabled") - return - } - w.mu.Lock() defer w.mu.Unlock() - if w.started { + w.startCalled = true + + if w.running { return } + if w.TaskTimeout == 0 { + log.Infof("Watchdog task timeout disabled") + return + } w.lastRun = w.k.MonotonicClock().Now() - log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction) + log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction) go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore. - w.started = true + w.running = true } // Stop requests the watchdog to stop and wait for it. func (w *Watchdog) Stop() { - if w.taskTimeout == 0 { + if w.TaskTimeout == 0 { return } w.mu.Lock() defer w.mu.Unlock() - if !w.started { + if !w.running { return } log.Infof("Stopping watchdog") w.stop <- struct{}{} <-w.done - w.started = false + w.running = false log.Infof("Watchdog stopped") } +// waitForStart waits for Start to be called and takes action if it does not +// happen within the startup timeout. +func (w *Watchdog) waitForStart() { + <-time.After(w.StartupTimeout) + w.mu.Lock() + defer w.mu.Unlock() + if w.startCalled { + // We are fine. + return + } + var buf bytes.Buffer + buf.WriteString("Watchdog.Start() not called within %s:\n") + w.doAction(w.StartupTimeoutAction, false, &buf) +} + // loop is the main watchdog routine. It only returns when 'Stop()' is called. func (w *Watchdog) loop() { // Loop until someone stops it. @@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() { select { case <-done: - case <-time.After(w.taskTimeout): + case <-time.After(w.TaskTimeout): // Report if the watchdog is not making progress. // No one is wathching the watchdog watcher though. w.reportStuckWatchdog() @@ -231,12 +282,14 @@ func (w *Watchdog) runTurn() { if tsched.State == kernel.TaskGoroutineRunningSys { lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick))) elapsed := now.Sub(lastUpdateTime) - discount - if elapsed > w.taskTimeout { + if elapsed > w.TaskTimeout { tc, ok := w.offenders[t] if !ok { // New stuck task detected. // - // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel. + // Note that tasks blocked doing IO may be considered stuck in kernel, + // unless they are surrounded b + // Task.UninterruptibleSleepStart/Finish. tc = &offender{lastUpdateTime: lastUpdateTime} stuckTasks.Increment() newTaskFound = true @@ -261,28 +314,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo tid := w.k.TaskSet().Root.IDOfTask(t) buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) } + buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - w.onStuckTask(newTaskFound, &buf) + + // Dump stack only if a new task is detected or if it sometime has + // passed since the last time a stack dump was generated. + skipStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod + w.doAction(w.TaskTimeoutAction, skipStack, &buf) } func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer buf.WriteString("Watchdog goroutine is stuck:\n") - w.onStuckTask(true, &buf) + w.doAction(w.TaskTimeoutAction, false, &buf) } -func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { - switch w.timeoutAction { +// doAction will take the given action. If the action is LogWarnind and +// skipStack is true, then the stack printing will be skipped. +func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { + switch action { case LogWarning: - // Dump stack only if a new task is detected or if it sometime has passed since - // the last time a stack dump was generated. - if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod { + if skipStack { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) - } else { - log.TracebackAll(msg.String()) - w.lastStackDump = time.Now() + return + } + log.TracebackAll(msg.String()) + w.lastStackDump = time.Now() case Panic: // Panic will skip over running tasks, which is likely the culprit here. So manually @@ -301,5 +360,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { case <-time.After(1 * time.Second): } panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + default: + panic(fmt.Sprintf("Unknown watchdog action %v", action)) + } } diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index 00665c939..a23c86fb1 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,6 +7,7 @@ go_library( name = "sleep", srcs = [ "commit_amd64.s", + "commit_arm64.s", "commit_asm.go", "commit_noasm.go", "sleep_unsafe.go", diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s new file mode 100644 index 000000000..d0ef15b20 --- /dev/null +++ b/pkg/sleep/commit_arm64.s @@ -0,0 +1,38 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +#define preparingG 1 + +// See commit_noasm.go for a description of commitSleep. +// +// func commitSleep(g uintptr, waitingG *uintptr) bool +TEXT ·commitSleep(SB),NOSPLIT,$0-24 + MOVD waitingG+8(FP), R0 + MOVD $preparingG, R1 + MOVD G+0(FP), R2 + + // Store the G in waitingG if it's still preparingG. If it's anything + // else it means a waker has aborted the sleep. +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE ok + STLXR R2, (R0), R3 + CBNZ R3, again +ok: + CSET EQ, R0 + MOVB R0, ret+16(FP) + RET diff --git a/pkg/sleep/commit_asm.go b/pkg/sleep/commit_asm.go index 35e2cc337..75728a97d 100644 --- a/pkg/sleep/commit_asm.go +++ b/pkg/sleep/commit_asm.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package sleep diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go index 686b1da3d..3af447fb9 100644 --- a/pkg/sleep/commit_noasm.go +++ b/pkg/sleep/commit_noasm.go @@ -13,7 +13,7 @@ // limitations under the License. // +build !race -// +build !amd64 +// +build !amd64,!arm64 package sleep diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 8f5e60a25..acbf0229b 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/state/BUILD b/pkg/state/BUILD index c0f3c658d..be93750bf 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,10 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "addr_range", out = "addr_range.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 47e6b878a..590c241a3 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState { // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { + // ctx is the decode context. + ctx context.Context + // objectByID is the set of objects in progress. objectsByID map[uint64]*objectState diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 5d9409a45..c5118d3a9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -16,6 +16,7 @@ package state import ( "container/list" + "context" "encoding/binary" "fmt" "io" @@ -38,6 +39,9 @@ type queuedObject struct { // The encoding process is a breadth-first traversal of the object graph. The // inherent races and dependencies are much simpler than the decode case. type encodeState struct { + // ctx is the encode context. + ctx context.Context + // lastID is the last object ID. // // See idsByObject for context. Because of the special zero encoding diff --git a/pkg/state/map.go b/pkg/state/map.go index 7e6fefed4..4f3ebb0da 100644 --- a/pkg/state/map.go +++ b/pkg/state/map.go @@ -15,6 +15,7 @@ package state import ( + "context" "fmt" "reflect" "sort" @@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) { // data dependencies have been cleared. m.os.callbacks = append(m.os.callbacks, fn) } + +// Context returns the current context object. +func (m Map) Context() context.Context { + if m.es != nil { + return m.es.ctx + } else if m.ds != nil { + return m.ds.ctx + } + return context.Background() // No context. +} diff --git a/pkg/state/object.proto b/pkg/state/object.proto index 952289069..5ebcfb151 100644 --- a/pkg/state/object.proto +++ b/pkg/state/object.proto @@ -18,8 +18,8 @@ package gvisor.state.statefile; // Slice is a slice value. message Slice { - uint32 length = 1; - uint32 capacity = 2; + uint32 length = 1; + uint32 capacity = 2; uint64 ref_value = 3; } @@ -30,13 +30,13 @@ message Array { // Map is a map value. message Map { - repeated Object keys = 1; + repeated Object keys = 1; repeated Object values = 2; } // Interface is an interface value. message Interface { - string type = 1; + string type = 1; Object value = 2; } @@ -47,7 +47,7 @@ message Struct { // Field encodes a single field. message Field { - string name = 1; + string name = 1; Object value = 2; } @@ -113,28 +113,28 @@ message Float32s { // Note that ref_value references an Object.id, below. message Object { oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; + bool bool_value = 1; + bytes string_value = 2; + int64 int64_value = 3; + uint64 uint64_value = 4; + double double_value = 5; + uint64 ref_value = 6; + Slice slice_value = 7; + Array array_value = 8; + Interface interface_value = 9; + Struct struct_value = 10; + Map map_value = 11; + bytes byte_array_value = 12; + Uint16s uint16_array_value = 13; + Uint32s uint32_array_value = 14; + Uint64s uint64_array_value = 15; + Uintptrs uintptr_array_value = 16; + Int8s int8_array_value = 17; + Int16s int16_array_value = 18; + Int32s int32_array_value = 19; + Int64s int64_array_value = 20; + Bools bool_array_value = 21; + Float64s float64_array_value = 22; + Float32s float32_array_value = 23; } } diff --git a/pkg/state/state.go b/pkg/state/state.go index d408ff84a..dbe507ab4 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -50,6 +50,7 @@ package state import ( + "context" "fmt" "io" "reflect" @@ -86,9 +87,10 @@ func UnwrapErrState(err error) error { } // Save saves the given object state. -func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { // Create the encoding state. es := &encodeState{ + ctx: ctx, idsByObject: make(map[uintptr]uint64), w: w, stats: stats, @@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { } // Load loads a checkpoint. -func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { // Create the decoding state. ds := &decodeState{ + ctx: ctx, objectsByID: make(map[uint64]*objectState), deferred: make(map[uint64]*pb.Object), r: r, diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 7c24bbcda..d7221e9e8 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "io/ioutil" "math" "reflect" @@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) { saveBuffer := &bytes.Buffer{} saveObjectPtr := reflect.New(reflect.TypeOf(root)) saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Save failed unexpectedly: %v", err) continue } else if err != nil { @@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) { // Load a new copy of the object. loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Load failed unexpectedly: %v", err) continue } else if err != nil { @@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) { bs := buildObject(b.N) var stats Stats b.StartTimer() - if err := Save(ioutil.Discard, bs, &stats); err != nil { + if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { b.Errorf("save failed: %v", err) } b.StopTimer() @@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) { bs := buildObject(b.N) var newBS benchStruct buf := &bytes.Buffer{} - if err := Save(buf, bs, nil); err != nil { + if err := Save(context.Background(), buf, bs, nil); err != nil { b.Errorf("save failed: %v", err) } var stats Stats b.StartTimer() - if err := Load(buf, &newBS, &stats); err != nil { + if err := Load(context.Background(), buf, &newBS, &stats); err != nil { b.Errorf("load failed: %v", err) } b.StopTimer() diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e70f4a79f..8a865d229 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD new file mode 100644 index 000000000..cb1f41628 --- /dev/null +++ b/pkg/syncutil/BUILD @@ -0,0 +1,52 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +exports_files(["LICENSE"]) + +go_template( + name = "generic_atomicptr", + srcs = ["atomicptr_unsafe.go"], + types = [ + "Value", + ], +) + +go_template( + name = "generic_seqatomic", + srcs = ["seqatomic_unsafe.go"], + types = [ + "Value", + ], + deps = [ + ":sync", + ], +) + +go_library( + name = "syncutil", + srcs = [ + "downgradable_rwmutex_unsafe.go", + "memmove_unsafe.go", + "norace_unsafe.go", + "race_unsafe.go", + "seqcount.go", + "syncutil.go", + ], + importpath = "gvisor.dev/gvisor/pkg/syncutil", +) + +go_test( + name = "syncutil_test", + size = "small", + srcs = [ + "downgradable_rwmutex_test.go", + "seqcount_test.go", + ], + embed = [":syncutil"], +) diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/syncutil/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md new file mode 100644 index 000000000..2183c4e20 --- /dev/null +++ b/pkg/syncutil/README.md @@ -0,0 +1,5 @@ +# Syncutil + +This package provides additional synchronization primitives not provided by the +Go stdlib 'sync' package. It is partially derived from the upstream 'sync' +package from go1.10. diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go new file mode 100644 index 000000000..525c4beed --- /dev/null +++ b/pkg/syncutil/atomicptr_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "sync/atomic" + "unsafe" +) + +// Value is a required type parameter. +type Value struct{} + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtr struct { + ptr unsafe.Pointer `state:".(*Value)"` +} + +func (p *AtomicPtr) savePtr() *Value { + return p.Load() +} + +func (p *AtomicPtr) loadPtr(v *Value) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtr) Load() *Value { + return (*Value)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtr) Store(x *Value) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD new file mode 100644 index 000000000..63f411a90 --- /dev/null +++ b/pkg/syncutil/atomicptrtest/BUILD @@ -0,0 +1,29 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "atomicptr_int", + out = "atomicptr_int_unsafe.go", + package = "atomicptr", + suffix = "Int", + template = "//pkg/syncutil:generic_atomicptr", + types = { + "Value": "int", + }, +) + +go_library( + name = "atomicptr", + srcs = ["atomicptr_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", +) + +go_test( + name = "atomicptr_test", + size = "small", + srcs = ["atomicptr_test.go"], + embed = [":atomicptr"], +) diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go new file mode 100644 index 000000000..8fdc5112e --- /dev/null +++ b/pkg/syncutil/atomicptrtest/atomicptr_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package atomicptr + +import ( + "testing" +) + +func newInt(val int) *int { + return &val +} + +func TestAtomicPtr(t *testing.T) { + var p AtomicPtrInt + if got := p.Load(); got != nil { + t.Errorf("initial value is %p (%v), wanted nil", got, got) + } + want := newInt(42) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } + want = newInt(100) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } +} diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go new file mode 100644 index 000000000..ffaf7ecc7 --- /dev/null +++ b/pkg/syncutil/downgradable_rwmutex_test.go @@ -0,0 +1,150 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// GOMAXPROCS=10 go test + +// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the +// addition of downgradingWriter and the renaming of num_iterations to +// numIterations to shut up Golint. + +package syncutil + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" +) + +func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { + m.RLock() + clocked <- true + <-cunlock + m.RUnlock() + cdone <- true +} + +func doTestParallelReaders(numReaders, gomaxprocs int) { + runtime.GOMAXPROCS(gomaxprocs) + var m DowngradableRWMutex + clocked := make(chan bool) + cunlock := make(chan bool) + cdone := make(chan bool) + for i := 0; i < numReaders; i++ { + go parallelReader(&m, clocked, cunlock, cdone) + } + // Wait for all parallel RLock()s to succeed. + for i := 0; i < numReaders; i++ { + <-clocked + } + for i := 0; i < numReaders; i++ { + cunlock <- true + } + // Wait for the goroutines to finish. + for i := 0; i < numReaders; i++ { + <-cdone + } +} + +func TestParallelReaders(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + doTestParallelReaders(1, 4) + doTestParallelReaders(3, 4) + doTestParallelReaders(4, 2) +} + +func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.RLock() + n := atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.Unlock() + } + cdone <- true +} + +func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.DowngradeLock() + n = atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + n = atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { + runtime.GOMAXPROCS(gomaxprocs) + // Number of active readers + 10000 * number of active writers. + var activity int32 + var rwm DowngradableRWMutex + cdone := make(chan bool) + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + var i int + for i = 0; i < numReaders/2; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + for ; i < numReaders; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + // Wait for the 4 writers and all readers to finish. + for i := 0; i < 4+numReaders; i++ { + <-cdone + } +} + +func TestDowngradableRWMutex(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + n := 1000 + if testing.Short() { + n = 5 + } + HammerDowngradableRWMutex(1, 1, n) + HammerDowngradableRWMutex(1, 3, n) + HammerDowngradableRWMutex(1, 10, n) + HammerDowngradableRWMutex(4, 1, n) + HammerDowngradableRWMutex(4, 3, n) + HammerDowngradableRWMutex(4, 10, n) + HammerDowngradableRWMutex(10, 1, n) + HammerDowngradableRWMutex(10, 3, n) + HammerDowngradableRWMutex(10, 10, n) + HammerDowngradableRWMutex(10, 5, n) +} diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go new file mode 100644 index 000000000..51e11555d --- /dev/null +++ b/pkg/syncutil/downgradable_rwmutex_unsafe.go @@ -0,0 +1,146 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +// This is mostly copied from the standard library's sync/rwmutex.go. +// +// Happens-before relationships indicated to the race detector: +// - Unlock -> Lock (via writerSem) +// - Unlock -> RLock (via readerSem) +// - RUnlock -> Lock (via writerSem) +// - DowngradeLock -> RLock (via readerSem) + +package syncutil + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +//go:linkname runtimeSemacquire sync.runtime_Semacquire +func runtimeSemacquire(s *uint32) + +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + +// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock +// method. +type DowngradableRWMutex struct { + w sync.Mutex // held if there are pending writers + writerSem uint32 // semaphore for writers to wait for completing readers + readerSem uint32 // semaphore for readers to wait for completing writers + readerCount int32 // number of pending readers + readerWait int32 // number of departing readers +} + +const rwmutexMaxReaders = 1 << 30 + +// RLock locks rw for reading. +func (rw *DowngradableRWMutex) RLock() { + if RaceEnabled { + RaceDisable() + } + if atomic.AddInt32(&rw.readerCount, 1) < 0 { + // A writer is pending, wait for it. + runtimeSemacquire(&rw.readerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.readerSem)) + } +} + +// RUnlock undoes a single RLock call. +func (rw *DowngradableRWMutex) RUnlock() { + if RaceEnabled { + RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) + RaceDisable() + } + if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { + if r+1 == 0 || r+1 == -rwmutexMaxReaders { + panic("RUnlock of unlocked DowngradableRWMutex") + } + // A writer is pending. + if atomic.AddInt32(&rw.readerWait, -1) == 0 { + // The last reader unblocks the writer. + runtimeSemrelease(&rw.writerSem, false, 0) + } + } + if RaceEnabled { + RaceEnable() + } +} + +// Lock locks rw for writing. +func (rw *DowngradableRWMutex) Lock() { + if RaceEnabled { + RaceDisable() + } + // First, resolve competition with other writers. + rw.w.Lock() + // Announce to readers there is a pending writer. + r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders + // Wait for active readers. + if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { + runtimeSemacquire(&rw.writerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.writerSem)) + } +} + +// Unlock unlocks rw for writing. +func (rw *DowngradableRWMutex) Unlock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.writerSem)) + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) + if r >= rwmutexMaxReaders { + panic("Unlock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. + for i := 0; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +func (rw *DowngradableRWMutex) DowngradeLock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer and one additional reader. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) + if r >= rwmutexMaxReaders+1 { + panic("DowngradeLock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. Note that this loop starts as 1 since r + // includes this goroutine. + for i := 1; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed to rw.w.Lock(). Note that they will still + // block on rw.writerSem since at least this reader exists, such that + // DowngradeLock() is atomic with the previous write lock. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go new file mode 100644 index 000000000..348675baa --- /dev/null +++ b/pkg/syncutil/memmove_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package syncutil + +import ( + "unsafe" +) + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't +// define it because go_generics can't update the go:linkname annotation. +// Furthermore, go:linkname silently doesn't work if the local name is exported +// (this is of course undocumented), which is why this indirection is +// necessary. +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go new file mode 100644 index 000000000..0a0a9deda --- /dev/null +++ b/pkg/syncutil/norace_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race + +package syncutil + +import ( + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = false + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { +} diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go new file mode 100644 index 000000000..206067ec1 --- /dev/null +++ b/pkg/syncutil/race_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package syncutil + +import ( + "runtime" + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = true + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { + runtime.RaceDisable() +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { + runtime.RaceEnable() +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { + runtime.RaceRelease(addr) +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go new file mode 100644 index 000000000..cb6d2eb22 --- /dev/null +++ b/pkg/syncutil/seqatomic_unsafe.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/syncutil" +) + +// Value is a required type parameter. +// +// Value must not contain any pointers, including interface objects, function +// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs +// containing any of the above. An init() function will panic if this property +// does not hold. +type Value struct{} + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Value + for { + epoch := sc.BeginRead() + if syncutil.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, + // so it can't help us here. Instead, call runtime.memmove + // directly, which is not instrumented by the race detector. + syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + // This is ~40% faster for short reads than going through memmove. + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { + var val Value + if syncutil.RaceEnabled { + syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func init() { + var val Value + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD new file mode 100644 index 000000000..ba18f3238 --- /dev/null +++ b/pkg/syncutil/seqatomictest/BUILD @@ -0,0 +1,35 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "seqatomic_int", + out = "seqatomic_int_unsafe.go", + package = "seqatomic", + suffix = "Int", + template = "//pkg/syncutil:generic_seqatomic", + types = { + "Value": "int", + }, +) + +go_library( + name = "seqatomic", + srcs = ["seqatomic_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", + deps = [ + "//pkg/syncutil", + ], +) + +go_test( + name = "seqatomic_test", + size = "small", + srcs = ["seqatomic_test.go"], + embed = [":seqatomic"], + deps = [ + "//pkg/syncutil", + ], +) diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go new file mode 100644 index 000000000..b0db44999 --- /dev/null +++ b/pkg/syncutil/seqatomictest/seqatomic_test.go @@ -0,0 +1,132 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seqatomic + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/syncutil" +) + +func TestSeqAtomicLoadUncontended(t *testing.T) { + var seq syncutil.SeqCount + const want = 1 + data := want + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadAfterWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadDuringWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicTryLoadUncontended(t *testing.T) { + var seq syncutil.SeqCount + const want = 1 + data := want + epoch := seq.BeginRead() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } +} + +func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } + seq.EndWrite() +} + +func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } +} + +func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { + var seq syncutil.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := SeqAtomicLoadInt(&seq, &data); got != want { + b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } + } + }) +} + +func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { + var seq syncutil.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + epoch := seq.BeginRead() + for pb.Next() { + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } + } + }) +} + +// For comparison: +func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { + var a atomic.Value + const want = 42 + a.Store(int(want)) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := a.Load().(int); got != want { + b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) + } + } + }) +} diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go new file mode 100644 index 000000000..11d8dbfaa --- /dev/null +++ b/pkg/syncutil/seqcount.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncutil + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" +) + +// SeqCount is a synchronization primitive for optimistic reader/writer +// synchronization in cases where readers can work with stale data and +// therefore do not need to block writers. +// +// Compared to sync/atomic.Value: +// +// - Mutation of SeqCount-protected data does not require memory allocation, +// whereas atomic.Value generally does. This is a significant advantage when +// writes are common. +// +// - Atomic reads of SeqCount-protected data require copying. This is a +// disadvantage when atomic reads are common. +// +// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other +// operations to be made atomic with reads of SeqCount-protected data. +// +// - SeqCount may be less flexible: as of this writing, SeqCount-protected data +// cannot include pointers. +// +// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected +// data require instantiating function templates using go_generics (see +// seqatomic.go). +type SeqCount struct { + // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd + // if a writer critical section is active, and a read from data protected + // by this SeqCount is atomic iff epoch is the same even value before and + // after the read. + epoch uint32 +} + +// SeqCountEpoch tracks writer critical sections in a SeqCount. +type SeqCountEpoch struct { + val uint32 +} + +// We assume that: +// +// - All functions in sync/atomic that perform a memory read are at least a +// read fence: memory reads before calls to such functions cannot be reordered +// after the call, and memory reads after calls to such functions cannot be +// reordered before the call, even if those reads do not use sync/atomic. +// +// - All functions in sync/atomic that perform a memory write are at least a +// write fence: memory writes before calls to such functions cannot be +// reordered after the call, and memory writes after calls to such functions +// cannot be reordered before the call, even if those writes do not use +// sync/atomic. +// +// As of this writing, the Go memory model completely fails to describe +// sync/atomic, but these properties are implied by +// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. + +// BeginRead indicates the beginning of a reader critical section. Reader +// critical sections DO NOT BLOCK writer critical sections, so operations in a +// reader critical section MAY RACE with writer critical sections. Races are +// detected by ReadOk at the end of the reader critical section. Thus, the +// low-level structure of readers is generally: +// +// for { +// epoch := seq.BeginRead() +// // do something idempotent with seq-protected data +// if seq.ReadOk(epoch) { +// break +// } +// } +// +// However, since reader critical sections may race with writer critical +// sections, the Go race detector will (accurately) flag data races in readers +// using this pattern. Most users of SeqCount will need to use the +// SeqAtomicLoad function template in seqatomic.go. +func (s *SeqCount) BeginRead() SeqCountEpoch { + epoch := atomic.LoadUint32(&s.epoch) + for epoch&1 != 0 { + runtime.Gosched() + epoch = atomic.LoadUint32(&s.epoch) + } + return SeqCountEpoch{epoch} +} + +// ReadOk returns true if the reader critical section initiated by a previous +// call to BeginRead() that returned epoch did not race with any writer critical +// sections. +// +// ReadOk may be called any number of times during a reader critical section. +// Reader critical sections do not need to be explicitly terminated; the last +// call to ReadOk is implicitly the end of the reader critical section. +func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { + return atomic.LoadUint32(&s.epoch) == epoch.val +} + +// BeginWrite indicates the beginning of a writer critical section. +// +// SeqCount does not support concurrent writer critical sections; clients with +// concurrent writers must synchronize them using e.g. sync.Mutex. +func (s *SeqCount) BeginWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { + panic("SeqCount.BeginWrite during writer critical section") + } +} + +// EndWrite ends the effect of a preceding BeginWrite. +func (s *SeqCount) EndWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { + panic("SeqCount.EndWrite outside writer critical section") + } +} + +// PointersInType returns a list of pointers reachable from values named +// valName of the given type. +// +// PointersInType is not exhaustive, but it is guaranteed that if typ contains +// at least one pointer, then PointersInTypeOf returns a non-empty list. +func PointersInType(typ reflect.Type, valName string) []string { + switch kind := typ.Kind(); kind { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return nil + + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: + return []string{valName} + + case reflect.Array: + return PointersInType(typ.Elem(), valName+"[]") + + case reflect.Struct: + var ptrs []string + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) + } + return ptrs + + default: + return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} + } +} diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go new file mode 100644 index 000000000..14d6aedea --- /dev/null +++ b/pkg/syncutil/seqcount_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncutil + +import ( + "reflect" + "testing" + "time" +) + +func TestSeqCountWriteUncontended(t *testing.T) { + var seq SeqCount + seq.BeginWrite() + seq.EndWrite() +} + +func TestSeqCountReadUncontended(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadAfterWrite(t *testing.T) { + var seq SeqCount + var data int32 + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadDuringWrite(t *testing.T) { + var seq SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountReadOkAfterWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } +} + +func TestSeqCountReadOkDuringWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } + seq.EndWrite() +} + +func BenchmarkSeqCountWriteUncontended(b *testing.B) { + var seq SeqCount + for i := 0; i < b.N; i++ { + seq.BeginWrite() + seq.EndWrite() + } +} + +func BenchmarkSeqCountReadUncontended(b *testing.B) { + var seq SeqCount + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + b.Fatalf("ReadOk: got false, wanted true") + } + } + }) +} + +func TestPointersInType(t *testing.T) { + for _, test := range []struct { + name string // used for both test and value name + val interface{} + ptrs []string + }{ + { + name: "EmptyStruct", + val: struct{}{}, + }, + { + name: "Int", + val: int(0), + }, + { + name: "MixedStruct", + val: struct { + b bool + I int + ExportedPtr *struct{} + unexportedPtr *struct{} + arr [2]int + ptrArr [2]*int + nestedStruct struct { + nestedNonptr int + nestedPtr *int + } + structArr [1]struct { + nonptr int + ptr *int + } + }{}, + ptrs: []string{ + "MixedStruct.ExportedPtr", + "MixedStruct.unexportedPtr", + "MixedStruct.ptrArr[]", + "MixedStruct.nestedStruct.nestedPtr", + "MixedStruct.structArr[].ptr", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + typ := reflect.TypeOf(test.val) + ptrs := PointersInType(typ, test.name) + t.Logf("Found pointers: %v", ptrs) + if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { + t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) + } + }) + } +} diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go new file mode 100644 index 000000000..66e750d06 --- /dev/null +++ b/pkg/syncutil/syncutil.go @@ -0,0 +1,7 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package syncutil provides synchronization primitives. +package syncutil diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index b149f9e02..bd3f9fd28 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index df37c7d5a..65d4d0cd8 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,10 +1,13 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "packet_buffer.go", + "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", ], diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 0d2637ee4..78df5a0b1 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 672f026b2..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { return nil, err @@ -148,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -200,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 3301967fb..d6c31bfa2 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "buffer", diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 4287464f3..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable { return Prependable{buf: v, usedIdx: 0} } +// NewEmptyPrependableFromView creates a new prependable buffer from a View. +func NewEmptyPrependableFromView(v View) Prependable { + return Prependable{buf: v, usedIdx: len(v)} +} + // View returns a View of the backing buffer that contains all prepended // data so far. func (p Prependable) View() View { @@ -72,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index 4cecfb989..b6fa6fc37 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 096ad71ab..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -639,6 +640,8 @@ func ICMPv4Code(want byte) TransportChecker { // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and // potentially additional ICMPv6 header fields. +// +// ICMPv6 will validate the checksum field before calling checkers. func ICMPv6(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -650,6 +653,10 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) + if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) + } + for _, f := range checkers { f(t, icmp) } @@ -686,3 +693,64 @@ func ICMPv6Code(want byte) TransportChecker { } } } + +// NDP creates a checker that checks that the packet contains a valid NDP +// message for type of ty, with potentially additional checks specified by +// checkers. +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDP message as far as the size of the message (minSize) is concerned. The +// values within the message are up to checkers to validate. +func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // Check normal ICMPv6 first. + ICMPv6( + ICMPv6Type(msgType), + ICMPv6Code(0))(t, h) + + last := h[len(h)-1] + + icmp := header.ICMPv6(last.Payload()) + if got := len(icmp.NDPPayload()); got < minSize { + t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) + } + + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// NDPNS creates a checker that checks that the packet contains a valid NDP +// Neighbor Solicitation message (as per the raw wire format), with potentially +// additional checks specified by checkers. +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNS message as far as the size of the messages concerned. The values within +// the message are up to checkers to validate. +func NDPNS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) +} + +// NDPNSTargetAddress creates a checker that checks the Target Address field of +// a header.NDPNeighborSolicit. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNS message as far as the size is concerned. +func NDPNSTargetAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + + if got := ns.TargetAddress(); got != want { + t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) + } + } +} diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 29b30be9c..e648efa71 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,9 +7,7 @@ go_library( name = "jenkins", srcs = ["jenkins.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 76ef02f13..f1d837196 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,6 +1,7 @@ -package(licenses = ["notice"]) +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +package(licenses = ["notice"]) go_library( name = "header", @@ -15,6 +16,10 @@ go_library( "ipv4.go", "ipv6.go", "ipv6_fragment.go", + "ndp_neighbor_advert.go", + "ndp_neighbor_solicit.go", + "ndp_options.go", + "ndp_router_advert.go", "tcp.go", "udp.go", ], @@ -29,13 +34,32 @@ go_library( ) go_test( - name = "header_test", + name = "header_x_test", size = "small", srcs = [ + "checksum_test.go", + "ipv6_test.go", "ipversion_test.go", "tcp_test.go", ], deps = [ ":header", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) + +go_test( + name = "header_test", + size = "small", + srcs = [ + "eth_test.go", + "ndp_test.go", + ], + embed = [":header"], + deps = [ + "//pkg/tcpip", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 39a4d69be..9749c7f4d 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -23,11 +23,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -func calculateChecksum(buf []byte, initial uint32) uint16 { +func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + l := len(buf) - if l&1 != 0 { + odd = l&1 != 0 + if odd { l-- v += uint32(buf[l]) << 8 } @@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } - return ChecksumCombine(uint16(v), uint16(v>>16)) + return ChecksumCombine(uint16(v), uint16(v>>16)), odd } // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the @@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - return calculateChecksum(buf, uint32(initial)) + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s } // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in @@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - var odd bool + return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) +} + +// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the +// bytes in the given VectorizedView. +// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { + odd := false sum := initial for _, v := range vv.Views() { if len(v) == 0 { continue } - s := uint32(sum) - if odd { - s += uint32(v[0]) - v = v[1:] + + if off >= len(v) { + off -= len(v) + continue + } + v = v[off:] + + l := len(v) + if l > size { + l = size + } + v = v[:l] + + sum, odd = calculateChecksum(v, odd, uint32(sum)) + + size -= len(v) + if size == 0 { + break } - odd = len(v)&1 != 0 - sum = calculateChecksum(v, s) + off = 0 } return sum } diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go new file mode 100644 index 000000000..86b466c1c --- /dev/null +++ b/pkg/tcpip/header/checksum_test.go @@ -0,0 +1,109 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package header provides the implementation of the encoding and decoding of +// network protocol headers. +package header_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestChecksumVVWithOffset(t *testing.T) { + testCases := []struct { + name string + vv buffer.VectorisedView + off, size int + initial uint16 + want uint16 + }{ + { + name: "empty", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 0, + want: 0, + }, + { + name: "OneView", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 5, + want: 1294, + }, + { + name: "TwoViews", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 0, + size: 11, + want: 33819, + }, + { + name: "TwoViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 1, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 7, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithInitial", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), + }), + initial: 77, + off: 7, + size: 11, + want: 33896, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { + t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + } + v := tc.vv.ToView() + v.TrimFront(tc.off) + v.CapLength(tc.size) + if got, want := header.Checksum(v, tc.initial), tc.want; got != want { + t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + } + }) + } +} diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 4c3d3311f..f5d2c127f 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -48,8 +48,48 @@ const ( // EthernetAddressSize is the size, in bytes, of an ethernet address. EthernetAddressSize = 6 + + // unspecifiedEthernetAddress is the unspecified ethernet address + // (all bits set to 0). + unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + + // unicastMulticastFlagMask is the mask of the least significant bit in + // the first octet (in network byte order) of an ethernet address that + // determines whether the ethernet address is a unicast or multicast. If + // the masked bit is a 1, then the address is a multicast, unicast + // otherwise. + // + // See the IEEE Std 802-2001 document for more details. Specifically, + // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf: + // "A 48-bit universal address consists of two parts. The first 24 bits + // correspond to the OUI as assigned by the IEEE, expect that the + // assignee may set the LSB of the first octet to 1 for group addresses + // or set it to 0 for individual addresses." + unicastMulticastFlagMask = 1 + + // unicastMulticastFlagByteIdx is the byte that holds the + // unicast/multicast flag. See unicastMulticastFlagMask. + unicastMulticastFlagByteIdx = 0 +) + +const ( + // EthernetProtocolAll is a catch-all for all protocols carried inside + // an ethernet frame. It is mainly used to create packet sockets that + // capture all traffic. + EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003 + + // EthernetProtocolPUP is the PARC Universial Packet protocol ethertype. + EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200 ) +// Ethertypes holds the protocol numbers describing the payload of an ethernet +// frame. These types aren't necessarily supported by netstack, but can be used +// to catch all traffic of a type via packet endpoints. +var Ethertypes = []tcpip.NetworkProtocolNumber{ + EthernetProtocolAll, + EthernetProtocolPUP, +} + // SourceAddress returns the "MAC source" field of the ethernet frame header. func (b Ethernet) SourceAddress() tcpip.LinkAddress { return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize]) @@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr) copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } + +// IsValidUnicastEthernetAddress returns true if addr is a valid unicast +// ethernet address. +func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { + // Must be of the right length. + if len(addr) != EthernetAddressSize { + return false + } + + // Must not be unspecified. + if addr == unspecifiedEthernetAddress { + return false + } + + // Must not be a multicast. + if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { + return false + } + + // addr is a valid unicast ethernet address. + return true +} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go new file mode 100644 index 000000000..6634c90f5 --- /dev/null +++ b/pkg/tcpip/header/eth_test.go @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestIsValidUnicastEthernetAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.LinkAddress + expected bool + }{ + { + "Nil", + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty", + tcpip.LinkAddress(""), + false, + }, + { + "InvalidLength", + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Unspecified", + unspecifiedEthernetAddress, + false, + }, + { + "Multicast", + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Valid", + tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), + true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { + t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 1125a7d14..b4037b6c8 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -25,6 +25,12 @@ import ( type ICMPv6 []byte const ( + // ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the + // sum of the size of the ICMPv6 Type, Code and Checksum fields, as + // per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6 + // message body begins. + ICMPv6HeaderSize = 4 + // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. ICMPv6MinimumSize = 8 @@ -37,10 +43,16 @@ const ( // ICMPv6NeighborSolicitMinimumSize is the minimum size of a // neighbor solicitation packet. - ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16 + ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize + + // ICMPv6NeighborAdvertMinimumSize is the minimum size of a + // neighbor advertisement packet. + ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize - // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. - ICMPv6NeighborAdvertSize = 32 + // ICMPv6NeighborAdvertSize is size of a neighbor advertisement + // including the NDP Target Link Layer option for an Ethernet + // address. + ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. ICMPv6EchoMinimumSize = 8 @@ -68,6 +80,13 @@ const ( // icmpv6SequenceOffset is the offset of the sequence field // in a ICMPv6 Echo Request/Reply message. icmpv6SequenceOffset = 6 + + // NDPHopLimit is the expected IP hop limit value of 255 for received + // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, + // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently + // drop the NDP packet. All outgoing NDP packets must use this value for + // its IP hop limit field. + NDPHopLimit = 255 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -113,7 +132,7 @@ func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } -// SetChecksum calculates and sets the ICMP checksum field. +// SetChecksum sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } @@ -166,12 +185,19 @@ func (b ICMPv6) SetSequence(sequence uint16) { binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) } +// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6 +// packet's message body as defined by RFC 4443 section 2.1; the portion of the +// ICMPv6 buffer after the first ICMPv6HeaderSize bytes. +func (b ICMPv6) NDPPayload() []byte { + return b[ICMPv6HeaderSize:] +} + // Payload implements Transport.Payload. func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } -// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { // Calculate the IPv6 pseudo-header upper-layer checksum. diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 554632a64..e5360e7c1 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -284,7 +284,7 @@ func (b IPv4) IsValid(pktSize int) bool { hlen := int(b.HeaderLength()) tlen := int(b.TotalLength()) - if hlen > tlen || tlen > pktSize { + if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize { return false } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 093850e25..fc671e439 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -76,15 +76,37 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 - // IPv6Any is the non-routable IPv6 "any" meta address. + // IPv6Any is the non-routable IPv6 "any" meta address. It is also + // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 ) -// IPv6EmptySubnet is the empty IPv6 subnet. +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. var IPv6EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) if err != nil { @@ -93,6 +115,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined +// by RFC 4291 section 2.5.6. +// +// The prefix is fe80::/64 +var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{ + Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 64, +} + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { @@ -221,6 +252,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool { return addr[0] == 0xff } +// IsV6UnicastAddress determines if the provided address is a valid IPv6 +// unicast (and specified) address. That is, IsV6UnicastAddress returns +// true if addr contains IPv6AddressSize bytes, is not the unspecified +// address and is not a multicast address. +func IsV6UnicastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + + // Must not be unspecified + if addr == IPv6Any { + return false + } + + // Return if not a multicast. + return addr[0] != 0xff +} + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. @@ -229,27 +278,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFF + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { - // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local - // header, FE80::. + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. // // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go new file mode 100644 index 000000000..42c5c6fc1 --- /dev/null +++ b/pkg/tcpip/header/ipv6_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +func TestEthernetAdddressToModifiedEUI64(t *testing.T) { + expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} + + if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) + } + + var buf [header.IIDSize]byte + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + if diff := cmp.Diff(expectedIID, buf); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) + } +} + +func TestLinkLocalAddr(t *testing.T) { + if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) + } +} diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go new file mode 100644 index 000000000..505c92668 --- /dev/null +++ b/pkg/tcpip/header/ndp_neighbor_advert.go @@ -0,0 +1,110 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import "gvisor.dev/gvisor/pkg/tcpip" + +// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will +// only contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.4 for more details. +type NDPNeighborAdvert []byte + +const ( + // NDPNAMinimumSize is the minimum size of a valid NDP Neighbor + // Advertisement message (body of an ICMPv6 packet). + NDPNAMinimumSize = 20 + + // ndpNATargetAddressOffset is the start of the Target Address + // field within an NDPNeighborAdvert. + ndpNATargetAddressOffset = 4 + + // ndpNAOptionsOffset is the start of the NDP options in an + // NDPNeighborAdvert. + ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize + + // ndpNAFlagsOffset is the offset of the flags within an + // NDPNeighborAdvert + ndpNAFlagsOffset = 0 + + // ndpNARouterFlagMask is the mask of the Router Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNARouterFlagMask = (1 << 7) + + // ndpNASolicitedFlagMask is the mask of the Solicited Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNASolicitedFlagMask = (1 << 6) + + // ndpNAOverrideFlagMask is the mask of the Override Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNAOverrideFlagMask = (1 << 5) +) + +// TargetAddress returns the value within the Target Address field. +func (b NDPNeighborAdvert) TargetAddress() tcpip.Address { + return tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]) +} + +// SetTargetAddress sets the value within the Target Address field. +func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) { + copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr) +} + +// RouterFlag returns the value of the Router Flag field. +func (b NDPNeighborAdvert) RouterFlag() bool { + return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0 +} + +// SetRouterFlag sets the value in the Router Flag field. +func (b NDPNeighborAdvert) SetRouterFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNARouterFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask + } +} + +// SolicitedFlag returns the value of the Solicited Flag field. +func (b NDPNeighborAdvert) SolicitedFlag() bool { + return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0 +} + +// SetSolicitedFlag sets the value in the Solicited Flag field. +func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask + } +} + +// OverrideFlag returns the value of the Override Flag field. +func (b NDPNeighborAdvert) OverrideFlag() bool { + return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0 +} + +// SetOverrideFlag sets the value in the Override Flag field. +func (b NDPNeighborAdvert) SetOverrideFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask + } +} + +// Options returns an NDPOptions of the the options body. +func (b NDPNeighborAdvert) Options() NDPOptions { + return NDPOptions(b[ndpNAOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go new file mode 100644 index 000000000..3a1b8e139 --- /dev/null +++ b/pkg/tcpip/header/ndp_neighbor_solicit.go @@ -0,0 +1,52 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import "gvisor.dev/gvisor/pkg/tcpip" + +// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only +// contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.3 for more details. +type NDPNeighborSolicit []byte + +const ( + // NDPNSMinimumSize is the minimum size of a valid NDP Neighbor + // Solicitation message (body of an ICMPv6 packet). + NDPNSMinimumSize = 20 + + // ndpNSTargetAddessOffset is the start of the Target Address + // field within an NDPNeighborSolicit. + ndpNSTargetAddessOffset = 4 + + // ndpNSOptionsOffset is the start of the NDP options in an + // NDPNeighborSolicit. + ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize +) + +// TargetAddress returns the value within the Target Address field. +func (b NDPNeighborSolicit) TargetAddress() tcpip.Address { + return tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]) +} + +// SetTargetAddress sets the value within the Target Address field. +func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) { + copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr) +} + +// Options returns an NDPOptions of the the options body. +func (b NDPNeighborSolicit) Options() NDPOptions { + return NDPOptions(b[ndpNSOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go new file mode 100644 index 000000000..06e0bace2 --- /dev/null +++ b/pkg/tcpip/header/ndp_options.go @@ -0,0 +1,589 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "errors" + "math" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // NDPTargetLinkLayerAddressOptionType is the type of the Target + // Link-Layer Address option, as per RFC 4861 section 4.6.1. + NDPTargetLinkLayerAddressOptionType = 2 + + // ndpTargetEthernetLinkLayerAddressSize is the size of a Target + // Link Layer Option for an Ethernet address. + ndpTargetEthernetLinkLayerAddressSize = 8 + + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType = 3 + + // ndpPrefixInformationLength is the expected length, in bytes, of the + // body of an NDP Prefix Information option, as per RFC 4861 section + // 4.6.2 which specifies that the Length field is 4. Given this, the + // expected length, in bytes, is 30 becuase 4 * lengthByteUnits (8) - 2 + // (Type & Length) = 30. + ndpPrefixInformationLength = 30 + + // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix + // Length field within an NDPPrefixInformation. + ndpPrefixInformationPrefixLengthOffset = 0 + + // ndpPrefixInformationFlagsOffset is the offset of the flags byte + // within an NDPPrefixInformation. + ndpPrefixInformationFlagsOffset = 1 + + // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationOnLinkFlagMask = (1 << 7) + + // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the + // Autonomous Address-Configuration flag field in the flags byte within + // an NDPPrefixInformation. + ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + + // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationReserved1FlagsMask = 63 + + // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte + // Valid Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationValidLifetimeOffset = 2 + + // ndpPrefixInformationPreferredLifetimeOffset is the start of the + // 4-byte Preferred Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationPreferredLifetimeOffset = 6 + + // ndpPrefixInformationReserved2Offset is the start of the 4-byte + // Reserved2 field within an NDPPrefixInformation. + ndpPrefixInformationReserved2Offset = 10 + + // ndpPrefixInformationReserved2Length is the length of the Reserved2 + // field. + // + // It is 4 bytes. + ndpPrefixInformationReserved2Length = 4 + + // ndpPrefixInformationPrefixOffset is the start of the Prefix field + // within an NDPPrefixInformation. + ndpPrefixInformationPrefixOffset = 14 + + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType = 25 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS + // Server option's length field value when it contains at least one + // IPv6 address. + minNDPRecursiveDNSServerLength = 3 + + // lengthByteUnits is the multiplier factor for the Length field of an + // NDP option. That is, the length field for NDP options is in units of + // 8 octets, as per RFC 4861 section 4.6. + lengthByteUnits = 8 +) + +var ( + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. + // + // This is a variable instead of a constant so that tests can change + // this value to a smaller value. It should only be modified by tests. + NDPInfiniteLifetime = time.Second * math.MaxUint32 +) + +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + // The NDPOptions this NDPOptionIterator is iterating over. + opts NDPOptions +} + +// Potential errors when iterating over an NDPOptions. +var ( + ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") + ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") +) + +// Next returns the next element in the backing NDPOptions, or true if we are +// done, or false if an error occured. +// +// The return can be read as option, done, error. Note, option should only be +// used if done is false and error is nil. +func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { + for { + // Do we still have elements to look at? + if len(i.opts) == 0 { + return nil, true, nil + } + + // Do we have enough bytes for an NDP option that has a Length + // field of at least 1? Note, 0 in the Length field is invalid. + if len(i.opts) < lengthByteUnits { + return nil, true, ErrNDPOptBufExhausted + } + + // Get the Type field. + t := i.opts[0] + + // Get the Length field. + l := i.opts[1] + + // This would indicate an erroneous NDP option as the Length + // field should never be 0. + if l == 0 { + return nil, true, ErrNDPOptZeroLength + } + + // How many bytes are in the option body? + numBytes := int(l) * lengthByteUnits + numBodyBytes := numBytes - 2 + + potentialBody := i.opts[2:] + + // This would indicate an erroenous NDPOptions buffer as we ran + // out of the buffer in the middle of an NDP option. + if left := len(potentialBody); left < numBodyBytes { + return nil, true, ErrNDPOptBufExhausted + } + + // Get only the options body, leaving the rest of the options + // buffer alone. + body := potentialBody[:numBodyBytes] + + // Update opts with the remaining options body. + i.opts = i.opts[numBytes:] + + switch t { + case NDPTargetLinkLayerAddressOptionType: + return NDPTargetLinkLayerAddressOption(body), false, nil + + case NDPPrefixInformationType: + // Make sure the length of a Prefix Information option + // body is ndpPrefixInformationLength, as per RFC 4861 + // section 4.6.2. + if numBodyBytes != ndpPrefixInformationLength { + return nil, true, ErrNDPOptMalformedBody + } + + return NDPPrefixInformation(body), false, nil + + case NDPRecursiveDNSServerOptionType: + // RFC 8106 section 5.3.1 outlines that the RDNSS option + // must have a minimum length of 3 so it contains at + // least one IPv6 address. + if l < minNDPRecursiveDNSServerLength { + return nil, true, ErrNDPInvalidLength + } + + opt := NDPRecursiveDNSServer(body) + if len(opt.Addresses()) == 0 { + return nil, true, ErrNDPOptMalformedBody + } + + return opt, false, nil + + default: + // We do not yet recognize the option, just skip for + // now. This is okay because RFC 4861 allows us to + // skip/ignore any unrecognized options. However, + // we MUST recognized all the options in RFC 4861. + // + // TODO(b/141487990): Handle all NDP options as defined + // by RFC 4861. + } + } +} + +// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6. +type NDPOptions []byte + +// Iter returns an iterator of NDPOption. +// +// If check is true, Iter will do an integrity check on the options by iterating +// over it and returning an error if detected. +// +// See NDPOptionIterator for more information. +func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { + it := NDPOptionIterator{opts: b} + + if check { + for it2 := it; true; { + if _, done, err := it2.Next(); err != nil || done { + return it, err + } + } + } + + return it, nil +} + +// Serialize serializes the provided list of NDP options into o. +// +// Note, b must be of sufficient size to hold all the options in s. See +// NDPOptionsSerializer.Length for details on the getting the total size +// of a serialized NDPOptionsSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { + done := 0 + + for _, o := range s { + l := paddedLength(o) + + if l == 0 { + continue + } + + b[0] = o.Type() + + // We know this safe because paddedLength would have returned + // 0 if o had an invalid length (> 255 * lengthByteUnits). + b[1] = uint8(l / lengthByteUnits) + + // Serialize NDP option body. + used := o.serializeInto(b[2:]) + + // Zero out remaining (padding) bytes, if any exists. + for i := used + 2; i < l; i++ { + b[i] = 0 + } + + b = b[l:] + done += l + } + + return done +} + +// NDPOption is the set of functions to be implemented by all NDP option types. +type NDPOption interface { + // Type returns the type of the receiver. + Type() uint8 + + // Length returns the length of the body of the receiver, in bytes. + Length() int + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) int +} + +// paddedLength returns the length of o, in bytes, with any padding bytes, if +// required. +func paddedLength(o NDPOption) int { + l := o.Length() + + if l == 0 { + return 0 + } + + // Length excludes the 2 Type and Length bytes. + l += 2 + + // Add extra bytes if needed to make sure the option is + // lengthByteUnits-byte aligned. We do this by adding lengthByteUnits-1 + // to l and then stripping off the last few LSBits from l. This will + // make sure that l is rounded up to the nearest unit of + // lengthByteUnits. This works since lengthByteUnits is a power of 2 + // (= 8). + mask := lengthByteUnits - 1 + l += mask + l &^= mask + + if l/lengthByteUnits > 255 { + // Should never happen because an option can only have a max + // value of 255 for its Length field, so just return 0 so this + // option does not get serialized. + // + // Returning 0 here will make sure that this option does not get + // serialized when NDPOptions.Serialize is called with the + // NDPOptionsSerializer that holds this option, effectively + // skipping this option during serialization. Also note that + // a value of zero for the Length field in an NDP option is + // invalid so this is another sign to the caller that this NDP + // option is malformed, as per RFC 4861 section 4.6. + return 0 + } + + return l +} + +// NDPOptionsSerializer is a serializer for NDP options. +type NDPOptionsSerializer []NDPOption + +// Length returns the total number of bytes required to serialize. +func (b NDPOptionsSerializer) Length() int { + l := 0 + + for _, o := range b { + l += paddedLength(o) + } + + return l +} + +// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option +// as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPTargetLinkLayerAddressOption tcpip.LinkAddress + +// Type implements NDPOption.Type. +func (o NDPTargetLinkLayerAddressOption) Type() uint8 { + return NDPTargetLinkLayerAddressOptionType +} + +// Length implements NDPOption.Length. +func (o NDPTargetLinkLayerAddressOption) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + +// NDPPrefixInformation is the NDP Prefix Information option as defined by +// RFC 4861 section 4.6.2. +// +// The length, in bytes, of a valid NDP Prefix Information option body MUST be +// ndpPrefixInformationLength bytes. +type NDPPrefixInformation []byte + +// Type implements NDPOption.Type. +func (o NDPPrefixInformation) Type() uint8 { + return NDPPrefixInformationType +} + +// Length implements NDPOption.Length. +func (o NDPPrefixInformation) Length() int { + return ndpPrefixInformationLength +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPPrefixInformation) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the Reserved1 field. + b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask + + // Zero out the Reserved2 field. + reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length] + for i := range reserved2 { + reserved2[i] = 0 + } + + return used +} + +// PrefixLength returns the value in the number of leading bits in the Prefix +// that are valid. +// +// Valid values are in the range [0, 128], but o may not always contain valid +// values. It is up to the caller to valdiate the Prefix Information option. +func (o NDPPrefixInformation) PrefixLength() uint8 { + return o[ndpPrefixInformationPrefixLengthOffset] +} + +// OnLinkFlag returns true of the prefix is considered on-link. On-link means +// that a forwarding node is not needed to send packets to other nodes on the +// same prefix. +// +// Note, when this function returns false, no statement is made about the +// on-link property of a prefix. That is, if OnLinkFlag returns false, the +// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any +// previously stored state for this prefix about its on-link status. +func (o NDPPrefixInformation) OnLinkFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0 +} + +// AutonomousAddressConfigurationFlag returns true if the prefix can be used for +// Stateless Address Auto-Configuration (as specified in RFC 4862). +func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0 +} + +// ValidLifetime returns the length of time that the prefix is valid for the +// purpose of on-link determination. This value is relative to the send time of +// the packet that the Prefix Information option was present in. +// +// Note, a value of 0 implies the prefix should not be considered as on-link, +// and a value of infinity/forever is represented by +// NDPInfiniteLifetime. +func (o NDPPrefixInformation) ValidLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) +} + +// PreferredLifetime returns the length of time that an address generated from +// the prefix via Stateless Address Auto-Configuration remains preferred. This +// value is relative to the send time of the packet that the Prefix Information +// option was present in. +// +// Note, a value of 0 implies that addresses generated from the prefix should +// no longer remain preferred, and a value of infinity is represented by +// NDPInfiniteLifetime. +// +// Also note that the value of this field MUST NOT exceed the Valid Lifetime +// field to avoid preferring addresses that are no longer valid, for the +// purpose of Stateless Address Auto-Configuration. +func (o NDPPrefixInformation) PreferredLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:])) +} + +// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix +// Length field (see NDPPrefixInformation.PrefixLength) contains the number +// of valid leading bits in the prefix. +// +// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field +// holds the link-local prefix (fe80::). +func (o NDPPrefixInformation) Prefix() tcpip.Address { + return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) +} + +// Subnet returns the Prefix field and Prefix Length field represented in a +// tcpip.Subnet. +func (o NDPPrefixInformation) Subnet() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: o.Prefix(), + PrefixLen: int(o.PrefixLength()), + } + return addrWithPrefix.Subnet() +} + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() uint8 { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, some of the addresses returned MAY be link-local addresses. +// +// Addresses may panic if o does not hold valid IPv6 addresses. +func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { + l := len(o) + if l < ndpRecursiveDNSServerAddressesOffset { + return nil + } + + l -= ndpRecursiveDNSServerAddressesOffset + if l%IPv6AddressSize != 0 { + return nil + } + + buf := o[ndpRecursiveDNSServerAddressesOffset:] + var addrs []tcpip.Address + for len(buf) > 0 { + addr := tcpip.Address(buf[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return nil + } + addrs = append(addrs, addr) + buf = buf[IPv6AddressSize:] + } + return addrs +} diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go new file mode 100644 index 000000000..bf7610863 --- /dev/null +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -0,0 +1,112 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "time" +) + +// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.2 for more details. +type NDPRouterAdvert []byte + +const ( + // NDPRAMinimumSize is the minimum size of a valid NDP Router + // Advertisement message (body of an ICMPv6 packet). + NDPRAMinimumSize = 12 + + // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field + // within an NDPRouterAdvert. + ndpRACurrHopLimitOffset = 0 + + // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags + // within an NDPRouterAdvert. + ndpRAFlagsOffset = 1 + + // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address + // Configuration flag within the bit-field/flags byte of an + // NDPRouterAdvert. + ndpRAManagedAddrConfFlagMask = (1 << 7) + + // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag + // within the bit-field/flags byte of an NDPRouterAdvert. + ndpRAOtherConfFlagMask = (1 << 6) + + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime + // field within an NDPRouterAdvert. + ndpRARouterLifetimeOffset = 2 + + // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time + // field within an NDPRouterAdvert. + ndpRAReachableTimeOffset = 4 + + // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer + // field within an NDPRouterAdvert. + ndpRARetransTimerOffset = 8 + + // ndpRAOptionsOffset is the start of the NDP options in an + // NDPRouterAdvert. + ndpRAOptionsOffset = 12 +) + +// CurrHopLimit returns the value of the Curr Hop Limit field. +func (b NDPRouterAdvert) CurrHopLimit() uint8 { + return b[ndpRACurrHopLimitOffset] +} + +// ManagedAddrConfFlag returns the value of the Managed Address Configuration +// flag. +func (b NDPRouterAdvert) ManagedAddrConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0 +} + +// OtherConfFlag returns the value of the Other Configuration flag. +func (b NDPRouterAdvert) OtherConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 +} + +// RouterLifetime returns the lifetime associated with the default router. A +// value of 0 means the source of the Router Advertisement is not a default +// router and SHOULD NOT appear on the default router list. Note, a value of 0 +// only means that the router should not be used as a default router, it does +// not apply to other information contained in the Router Advertisement. +func (b NDPRouterAdvert) RouterLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.2. + return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:])) +} + +// ReachableTime returns the time that a node assumes a neighbor is reachable +// after having received a reachability confirmation. A value of 0 means +// that it is unspecified by the source of the Router Advertisement message. +func (b NDPRouterAdvert) ReachableTime() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:])) +} + +// RetransTimer returns the time between retransmitted Neighbor Solicitation +// messages. A value of 0 means that it is unspecified by the source of the +// Router Advertisement message. +func (b NDPRouterAdvert) RetransTimer() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:])) +} + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterAdvert) Options() NDPOptions { + return NDPOptions(b[ndpRAOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go new file mode 100644 index 000000000..2c439d70c --- /dev/null +++ b/pkg/tcpip/header/ndp_test.go @@ -0,0 +1,785 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bytes" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit. +func TestNDPNeighborSolicit(t *testing.T) { + b := []byte{ + 0, 0, 0, 0, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + } + + // Test getting the Target Address. + ns := NDPNeighborSolicit(b) + addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + if got := ns.TargetAddress(); got != addr { + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) + } + + // Test updating the Target Address. + addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + ns.SetTargetAddress(addr2) + if got := ns.TargetAddress(); got != addr2 { + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) + } + // Make sure the address got updated in the backing buffer. + if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) + } +} + +// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. +func TestNDPNeighborAdvert(t *testing.T) { + b := []byte{ + 160, 0, 0, 0, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + } + + // Test getting the Target Address. + na := NDPNeighborAdvert(b) + addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + if got := na.TargetAddress(); got != addr { + t.Errorf("got TargetAddress = %s, want %s", got, addr) + } + + // Test getting the Router Flag. + if got := na.RouterFlag(); !got { + t.Errorf("got RouterFlag = false, want = true") + } + + // Test getting the Solicited Flag. + if got := na.SolicitedFlag(); got { + t.Errorf("got SolicitedFlag = true, want = false") + } + + // Test getting the Override Flag. + if got := na.OverrideFlag(); !got { + t.Errorf("got OverrideFlag = false, want = true") + } + + // Test updating the Target Address. + addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + na.SetTargetAddress(addr2) + if got := na.TargetAddress(); got != addr2 { + t.Errorf("got TargetAddress = %s, want %s", got, addr2) + } + // Make sure the address got updated in the backing buffer. + if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) + } + + // Test updating the Router Flag. + na.SetRouterFlag(false) + if got := na.RouterFlag(); got { + t.Errorf("got RouterFlag = true, want = false") + } + + // Test updating the Solicited Flag. + na.SetSolicitedFlag(true) + if got := na.SolicitedFlag(); !got { + t.Errorf("got SolicitedFlag = false, want = true") + } + + // Test updating the Override Flag. + na.SetOverrideFlag(false) + if got := na.OverrideFlag(); got { + t.Errorf("got OverrideFlag = true, want = false") + } + + // Make sure flags got updated in the backing buffer. + if got := b[ndpNAFlagsOffset]; got != 64 { + t.Errorf("got flags byte = %d, want = 64") + } +} + +func TestNDPRouterAdvert(t *testing.T) { + b := []byte{ + 64, 128, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + + ra := NDPRouterAdvert(b) + + if got := ra.CurrHopLimit(); got != 64 { + t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) + } + + if got := ra.ManagedAddrConfFlag(); !got { + t.Errorf("got ManagedAddrConfFlag = false, want = true") + } + + if got := ra.OtherConfFlag(); got { + t.Errorf("got OtherConfFlag = true, want = false") + } + + if got, want := ra.RouterLifetime(), time.Second*258; got != want { + t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { + t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { + t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + } +} + +// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "TLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "TLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tll := NDPTargetLinkLayerAddressOption(test.buf) + if got := tll.EthernetAddress(); got != test.expected { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) + } + +} + +// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a +// NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { + tests := []struct { + name string + buf []byte + expectedBuf []byte + addr tcpip.LinkAddress + }{ + { + "Ethernet", + make([]byte, 8), + []byte{2, 1, 1, 2, 3, 4, 5, 6}, + "\x01\x02\x03\x04\x05\x06", + }, + { + "Padding", + []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + "\x01\x02\x03\x04\x05\x06\x07\x08", + }, + { + "Empty", + []byte{}, + []byte{}, + "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + serializer := NDPOptionsSerializer{ + NDPTargetLinkLayerAddressOption(test.addr), + } + if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { + t.Fatalf("got Length = %d, want = %d", got, want) + } + opts.Serialize(serializer) + if !bytes.Equal(test.buf, test.expectedBuf) { + t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + tll := next.(NDPTargetLinkLayerAddressOption) + if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got tll.MACAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPPrefixInformationOption tests the field getters and serialization of a +// NDPPrefixInformation. +func TestNDPPrefixInformationOption(t *testing.T) { + b := []byte{ + 43, 127, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 5, 5, 5, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPPrefixInformation(b), + } + opts.Serialize(serializer) + expectedBuf := []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + if !bytes.Equal(targetBuf, expectedBuf) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + pi := next.(NDPPrefixInformation) + + if got := pi.Type(); got != 3 { + t.Errorf("got Type = %d, want = 3", got) + } + + if got := pi.Length(); got != 30 { + t.Errorf("got Length = %d, want = 30", got) + } + + if got := pi.PrefixLength(); got != 43 { + t.Errorf("got PrefixLength = %d, want = 43", got) + } + + if pi.OnLinkFlag() { + t.Error("got OnLinkFlag = true, want = false") + } + + if !pi.AutonomousAddressConfigurationFlag() { + t.Error("got AutonomousAddressConfigurationFlag = false, want = true") + } + + if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { + t.Errorf("got ValidLifetime = %d, want = %d", got, want) + } + + if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { + t.Errorf("got PreferredLifetime = %d, want = %d", got, want) + } + + if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { + t.Errorf("got Prefix = %s, want = %s", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 25, 3, 0, 0, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPRecursiveDNSServer(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Type(); got != 25 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16909320*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + want := []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + } + if got := opt.Addresses(); !cmp.Equal(got, want) { + t.Errorf("got Addresses = %v, want = %v", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + addrs []tcpip.Address + }{ + { + "Valid1Addr", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + }, + }, + { + "Valid2Addr", + []byte{ + 25, 5, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + }, + }, + { + "Valid3Addr", + []byte{ + 25, 7, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Iterator should get our option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { + t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions +// the iterator was returned for is malformed. +func TestNDPOptionsIterCheck(t *testing.T) { + tests := []struct { + name string + buf []byte + expected error + }{ + { + "ZeroLengthField", + []byte{0, 0, 0, 0, 0, 0, 0, 0}, + ErrNDPOptZeroLength, + }, + { + "ValidTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5, 6}, + nil, + }, + { + "TooSmallTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5}, + ErrNDPOptBufExhausted, + }, + { + "ValidPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "TooSmallPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, + }, + ErrNDPOptBufExhausted, + }, + { + "InvalidPrefixInformationLength", + []byte{ + 3, 3, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + }, + ErrNDPOptMalformedBody, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformation", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up + // being the type for some recognized type, + // update 255 to some other unrecognized value. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "InvalidRecursiveDNSServerCutsOffAddress", + []byte{ + 25, 4, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7, + }, + ErrNDPOptMalformedBody, + }, + { + "InvalidRecursiveDNSServerInvalidLengthField", + []byte{ + 25, 2, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, + }, + ErrNDPInvalidLength, + }, + { + "RecursiveDNSServerTooSmall", + []byte{ + 25, 1, 0, 0, + 0, 0, 0, + }, + ErrNDPOptBufExhausted, + }, + { + "RecursiveDNSServerMulticast", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + ErrNDPOptMalformedBody, + }, + { + "RecursiveDNSServerUnspecified", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + ErrNDPOptMalformedBody, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + + if _, err := opts.Iter(true); err != test.expected { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) + } + + // test.buf may be malformed but we chose not to check + // the iterator so it must return true. + if _, err := opts.Iter(false); err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } + }) + } +} + +// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, +// this test does not actually check any of the option's getters, it simply +// checks the option Type and Body. We have other tests that tests the option +// field gettings given an option body and don't need to duplicate those tests +// here. +func TestNDPOptionsIter(t *testing.T) { + buf := []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up being the type + // for some recognized type, update 255 to some other + // unrecognized value. Note, this option should be skipped when + // iterating. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + opts := NDPOptions(buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Test the first (Taret Link-Layer) option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + + // Test the next (Prefix Information) option. + // Note, the unrecognized option should be skipped. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 3fc14bacd..cc5f531e2 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "iptables", srcs = [ diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 97a794986..7dbc05754 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -6,7 +6,7 @@ go_library( name = "channel", srcs = ["channel.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c40744b8e..70188551f 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -25,10 +25,9 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Header buffer.View - Payload buffer.View - Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO + Pkt tcpip.PacketBuffer + Proto tcpip.NetworkProtocolNumber + GSO *stack.GSO } // Endpoint is link layer endpoint that stores outbound packets in a channel @@ -44,14 +43,12 @@ type Endpoint struct { } // New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { + return &Endpoint{ C: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } - - return stack.RegisterLinkEndpoint(e), e } // Drain removes all outbound packets from the channel and counts them. @@ -67,14 +64,14 @@ func (e *Endpoint) Drain() int { } } -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) +// InjectInbound injects an inbound packet. +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil)) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -98,7 +95,7 @@ func (e *Endpoint) MTU() uint32 { func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { caps := stack.LinkEndpointCapabilities(0) if e.GSO { - caps |= stack.CapabilityGSO + caps |= stack.CapabilityHardwareGSO } return caps } @@ -120,12 +117,55 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + p := PacketInfo{ + Pkt: pkt, + Proto: protocol, + GSO: gso, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + payloadView := pkts[0].Data.ToView() + n := 0 +packetLoop: + for _, pkt := range pkts { + off := pkt.DataOffset + size := pkt.DataSize + p := PacketInfo{ + Pkt: tcpip.PacketBuffer{ + Header: pkt.Header, + Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), + }, + Proto: protocol, + GSO: gso, + } + + select { + case e.C <- p: + n++ + default: + break packetLoop + } + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, + Pkt: tcpip.PacketBuffer{Data: vv}, + Proto: 0, + GSO: nil, } select { @@ -135,3 +175,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 74fbbb896..897c94821 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -13,9 +14,7 @@ go_library( "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 77f988b9f..fa8a703d9 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,6 +41,7 @@ package fdbased import ( "fmt" + "sync" "syscall" "golang.org/x/sys/unix" @@ -81,6 +82,19 @@ const ( PacketMMap ) +func (p PacketDispatchMode) String() string { + switch p { + case Readv: + return "Readv" + case RecvMMsg: + return "RecvMMsg" + case PacketMMap: + return "PacketMMap" + default: + return fmt.Sprintf("unknown packet dispatch mode %v", p) + } +} + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -114,6 +128,9 @@ type endpoint struct { // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is // disabled. gsoMaxSize uint32 + + // wg keeps track of running goroutines. + wg sync.WaitGroup } // Options specify the details about the fd-based endpoint to be created. @@ -148,6 +165,9 @@ type Options struct { // disabled. GSOMaxSize uint32 + // SoftwareGSOEnabled indicates whether software GSO is enabled or not. + SoftwareGSOEnabled bool + // PacketDispatchMode specifies the type of inbound dispatcher to be // used for this endpoint. PacketDispatchMode PacketDispatchMode @@ -161,11 +181,20 @@ type Options struct { RXChecksumOffload bool } +// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT +// support in the host kernel. This allows us to use multiple FD's to receive +// from the same underlying NIC. The fanoutID needs to be the same for a given +// set of FD's that point to the same NIC. Trying to set the PACKET_FANOUT +// option for an FD with a fanoutID already in use by another FD for a different +// NIC will return an EINVAL. +var fanoutID = 1 + // New creates a new fd-based endpoint. // // Makes fd non-blocking, but does not take ownership of fd, which must remain -// open for the lifetime of the returned endpoint. -func New(opts *Options) (tcpip.LinkEndpointID, error) { +// open for the lifetime of the returned endpoint (until after the endpoint has +// stopped being using and Wait returns). +func New(opts *Options) (stack.LinkEndpoint, error) { caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -190,7 +219,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } if len(opts.FDs) == 0 { - return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } e := &endpoint{ @@ -207,27 +236,35 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { for i := 0; i < len(e.fds); i++ { fd := e.fds[i] if err := syscall.SetNonblock(fd, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) + return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } isSocket, err := isSocketFD(fd) if err != nil { - return 0, err + return nil, err } if isSocket { if opts.GSOMaxSize != 0 { - e.caps |= stack.CapabilityGSO + if opts.SoftwareGSOEnabled { + e.caps |= stack.CapabilitySoftwareGSO + } else { + e.caps |= stack.CapabilityHardwareGSO + } e.gsoMaxSize = opts.GSOMaxSize } } inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err) } e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } - return stack.RegisterLinkEndpoint(e), nil + // Increment fanoutID to ensure that we don't re-use the same fanoutID for + // the next endpoint. + fanoutID++ + + return e, nil } func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { @@ -247,7 +284,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher case *unix.SockaddrLinklayer: // enable PACKET_FANOUT mode is the underlying socket is // of type AF_PACKET. - const fanoutID = 1 const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { @@ -290,7 +326,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { // saved, they stop sending outgoing packets and all incoming packets // are rejected. for i := range e.inboundDispatchers { - go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above. + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) } } @@ -320,6 +360,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } +// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop +// reading from its FD. +func (e *endpoint) Wait() { + e.wg.Wait() +} + // virtioNetHdr is declared in linux/virtio_net.h. type virtioNetHdr struct { flags uint8 @@ -340,10 +386,11 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -358,17 +405,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen eth.Encode(ethHdr) } - if e.Capabilities()&stack.CapabilityGSO != 0 { + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { - vnetHdr.hdrLen = uint16(hdr.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -381,18 +428,151 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } } - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView()) + return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + } + + if pkt.Data.Size() == 0 { + return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View()) + } + + return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil) +} + +// WritePackets writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var ethHdrBuf []byte + // hdr + data + iovLen := 2 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } + + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ + } + + n := len(pkts) + + views := pkts[0].Data.Views() + /* + * Each bondary in views can add one more iovec. + * + * payload | | | | + * ----------------------------- + * packets | | | | | | | + * ----------------------------- + * iovecs | | | | | | | | | + */ + iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) + mmsgHdrs := make([]rawfile.MMsgHdr, n) + + iovecIdx := 0 + viewIdx := 0 + viewOff := 0 + off := 0 + nextOff := 0 + for i := range pkts { + // TODO(b/134618279): Different packets may have different data + // in the future. We should handle this. + if !viewsEqual(pkts[i].Data.Views(), views) { + panic("All packets in pkts should have the same Data.") + } + + prevIovecIdx := iovecIdx + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovec[iovecIdx] + packetSize := pkts[i].DataSize + hdr := &pkts[i].Header + + off = pkts[i].DataOffset + if off != nextOff { + // We stop in a different point last time. + size := packetSize + viewIdx = 0 + viewOff = 0 + for size > 0 { + if size >= len(views[viewIdx]) { + viewIdx++ + viewOff = 0 + size -= len(views[viewIdx]) + } else { + viewOff = size + size = 0 + } + } + } + nextOff = off + packetSize + + if ethHdrBuf != nil { + v := &iovec[iovecIdx] + v.Base = ðHdrBuf[0] + v.Len = uint64(len(ethHdrBuf)) + iovecIdx++ + } + + v := &iovec[iovecIdx] + hdrView := hdr.View() + v.Base = &hdrView[0] + v.Len = uint64(len(hdrView)) + iovecIdx++ + + for packetSize > 0 { + vec := &iovec[iovecIdx] + iovecIdx++ + + v := views[viewIdx] + vec.Base = &v[viewOff] + s := len(v) - viewOff + if s <= packetSize { + viewIdx++ + viewOff = 0 + } else { + s = packetSize + viewOff += s + } + vec.Len = uint64(s) + packetSize -= s + } + + mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) } - if payload.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], hdr.View()) + packets := 0 + for packets < n { + sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] } + return packets, nil +} - return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) +// viewsEqual tests whether v1 and v2 refer to the same backing bytes. +func viewsEqual(vs1, vs2 []buffer.View) bool { + return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) } -// WriteRawPacket writes a raw packet directly to the file descriptor. -func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) +} + +// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } @@ -429,20 +609,18 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher } -// Inject injects an inbound packet. -func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound injects an inbound packet. +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. -func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) { +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint { syscall.SetNonblock(fd, true) - e := &InjectableEndpoint{endpoint: endpoint{ + return &InjectableEndpoint{endpoint: endpoint{ fds: []int{fd}, mtu: mtu, caps: capabilities, }} - - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e305252d6..2066987eb 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents tcpip.PacketBuffer } type context struct { @@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context { } opt.FDs = []int{fds[1]} - epID, err := New(opt) + ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } - ep := stack.FindLinkEndpoint(epID).(*endpoint) c := &context{ t: t, @@ -93,8 +92,8 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - c.ch <- packetInfo{remote, protocol, vv.ToView()} +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} } func TestNoEthernetProperties(t *testing.T) { @@ -169,7 +168,10 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -259,7 +261,10 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.VectorisedView{}, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -294,11 +299,12 @@ func TestDeliverPacket(t *testing.T) { b[i] = uint8(rand.Intn(256)) } + var hdr header.Ethernet if !eth { // So that it looks like an IPv4 packet. b[0] = 0x40 } else { - hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr = make(header.Ethernet, header.EthernetMinimumSize) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, @@ -316,14 +322,21 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: b, + raddr: raddr, + proto: proto, + contents: tcpip.PacketBuffer{ + Data: buffer.View(b).ToVectorisedView(), + LinkHeader: buffer.View(hdr), + }, } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents.Data will be a single + // view, so make pi do the same for the + // DeepEqual check. + pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 8bfeb97e4..62ed1e569 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) + eth = header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -189,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + Data: buffer.View(pkt).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 7ca217e5b..c67d684ce 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} d.views = make([]buffer.View, len(BufConfig)) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { iovLen++ } d.iovecs = make([]syscall.Iovec, iovLen) @@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { func (d *readVDispatcher) allocateViews(bufConfig []int) { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { if err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // Skip virtioNetHdr which is added before each packet, it // isn't used and it isn't in a view. n -= virtioNetHdrSize @@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[0]) + eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) - vv.TrimFront(d.e.hdrSize) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -194,7 +198,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { } d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // virtioNetHdr is prepended before each packet. iovLen++ } @@ -225,7 +229,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { for k := 0; k < len(d.views); k++ { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -261,7 +265,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } if n <= d.e.hdrSize { @@ -271,9 +275,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -291,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index 47a54845c..f35fcdff4 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -6,10 +6,11 @@ go_library( name = "loopback", srcs = ["loopback.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ab6a53988..499cc608f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -23,6 +23,7 @@ package loopback import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -32,8 +33,8 @@ type endpoint struct { // New creates a new loopback endpoint. This link-layer endpoint just turns // outbound packets into inbound packets. -func New() tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{}) +func New() stack.LinkEndpoint { + return &endpoint{} } // Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- @@ -70,18 +71,45 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - - // Because we're immediately turning around and writing the packet back to the - // rx path, we intentionally don't preserve the remote and local link - // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { + return tcpip.ErrBadAddress + } + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + vv.TrimFront(len(linkHeader)) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + Data: vv, + LinkHeader: buffer.View(linkHeader), + }) return nil } diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index ea12ef1ac..1ac7948b6 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,9 +7,7 @@ go_library( name = "muxed", srcs = ["injectable.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a577a3d52..445b22c17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -79,35 +79,59 @@ func (m *InjectableEndpoint) IsAttached() bool { return m.dispatcher != nil } -// Inject implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound implements stack.InjectableLinkEndpoint. +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +} + +// WritePackets writes outbound packets to the appropriate +// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if +// r.RemoteAddress has a route registered in this endpoint. +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + endpoint, ok := m.routes[r.RemoteAddress] + if !ok { + return 0, tcpip.ErrNoRoute + } + return endpoint.WritePackets(r, gso, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, protocol, pkt) } return tcpip.ErrNoRoute } -// WriteRawPacket writes outbound packets to the appropriate +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // WriteRawPacket doesn't get a route or network address, so there's + // nowhere to write this. + return tcpip.ErrNoRoute +} + +// InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { endpoint, ok := m.routes[dest] if !ok { return tcpip.ErrNoRoute } - return endpoint.WriteRawPacket(dest, packet) + return endpoint.InjectOutbound(dest, packet) +} + +// Wait implements stack.LinkEndpoint.Wait. +func (m *InjectableEndpoint) Wait() { + for _, ep := range m.routes { + ep.Wait() + } } // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ routes: routes, } - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 174b9330f..63b249837 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -31,7 +31,7 @@ import ( func TestInjectableEndpointRawDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - endpoint.WriteRawPacket(dstIP, []byte{0xFA}) + endpoint.InjectOutbound(dstIP, []byte{0xFA}) buf := make([]byte, ipv4.MaxTotalSize) bytesRead, err := sock.Read(buf) @@ -50,8 +50,10 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -68,8 +70,10 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewView(0).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { @@ -87,8 +91,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc if err != nil { t.Fatal("Failed to create socket pair:", err) } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) + underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) + endpoint := NewInjectableEndpoint(routes) return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP } diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 2e8bc772a..d8211e93d 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -13,8 +13,9 @@ go_library( "rawfile_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = [ - "//visibility:public", + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@org_golang_x_sys//unix:go_default_library", ], - deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index dda3b10a6..0b5a6cf49 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 7e286a3a6..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -22,6 +22,7 @@ import ( "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { return nil } +// NonBlockingSendMMsg sends multiple messages on a socket. +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { + n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) + if e != 0 { + return 0, TranslateErrno(e) + } + + return int(n), nil +} + // PollEvent represents the pollfd structure passed to a poll() system call. type PollEvent struct { FD int32 diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index f2998aa98..a4f9cdd69 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -11,9 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = [ - "//:sandbox", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 94725cb11..6b5bc542c 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -11,7 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 160a8f864..8c9234d54 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -9,7 +10,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip/link/sharedmem/pipe", diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 834ea5c40..080f9d667 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -94,7 +94,7 @@ type endpoint struct { // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { e := &endpoint{ mtu: mtu, bufferSize: bufferSize, @@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc } if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err + return nil, err } if err := e.rx.init(bufferSize, &rx); err != nil { e.tx.cleanup() - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } // Close frees all resources associated with the endpoint. @@ -132,7 +132,8 @@ func (e *endpoint) Close() { } } -// Wait waits until all workers have stopped after a Close() call. +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. func (e *endpoint) Wait() { e.completed.Wait() } @@ -184,9 +185,10 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -198,10 +200,30 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa } eth.Encode(ethHdr) - v := payload.ToView() + v := pkt.Data.ToView() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) + ok := e.tx.transmit(pkt.Header.View(), v) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + v := vv.ToView() + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(v, buffer.View{}) e.mu.Unlock() if !ok { @@ -252,8 +274,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { } // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView()) + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 98036f367..89603c48f 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() { } type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView + addr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + vv buffer.VectorisedView + linkHeader buffer.View } type testContext struct { @@ -119,23 +120,23 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) if err != nil { t.Fatalf("New failed: %v", err) } - c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep = ep.(*endpoint) c.ep.Attach(c) return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, proto: proto, - vv: vv.Clone(nil), + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() @@ -272,7 +273,10 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -341,7 +345,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -395,7 +401,10 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -410,7 +419,10 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -435,7 +447,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -455,7 +470,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -470,7 +488,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -493,7 +514,10 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -509,7 +533,10 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -534,7 +561,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -547,7 +577,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: uu, + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -555,7 +588,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 1756114e6..d6ae0368a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -9,9 +9,7 @@ go_library( "sniffer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 36c8c46fc..3392b7edd 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -49,6 +49,13 @@ var LogPackets uint32 = 1 // LogPacketsToFile must be accessed atomically. var LogPacketsToFile uint32 = 1 +var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{ + header.ICMPv4ProtocolNumber: header.IPv4MinimumSize, + header.ICMPv6ProtocolNumber: header.IPv6MinimumSize, + header.UDPProtocolNumber: header.UDPMinimumSize, + header.TCPProtocolNumber: header.TCPMinimumSize, +} + type endpoint struct { dispatcher stack.NetworkDispatcher lower stack.LinkEndpoint @@ -58,10 +65,10 @@ type endpoint struct { // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. -func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), - }) +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + return &endpoint{ + lower: lower, + } } func zoneOffset() (int32, error) { @@ -102,33 +109,33 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal too snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { +func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { if err := writePCAPHeader(file, snapLen); err != nil { - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), + return &endpoint{ + lower: lower, file: file, maxPCAPLen: snapLen, - }), nil + }, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, pkt.Data.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() + vs := pkt.Data.Views() + length := pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { panic(err) } for _, v := range vs { @@ -147,7 +154,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local panic(err) } } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher @@ -193,22 +200,19 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket implements the stack.LinkEndpoint interface. It is called by -// higher-level protocols to write packets; it just logs the packet and forwards -// the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View(), gso) + logPacket("send", protocol, pkt.Header.View(), gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := hdr.View() - length := len(hdrBuf) + payload.Size() + hdrBuf := pkt.Header.View() + length := len(hdrBuf) + pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil { panic(err) } if len(hdrBuf) > length { @@ -218,28 +222,80 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen panic(err) } length -= len(hdrBuf) - if length > 0 { - for _, v := range payload.Views() { - if len(v) > length { - v = v[:length] - } - n, err := buf.Write(v) - if err != nil { - panic(err) - } - length -= n - if length == 0 { - break - } - } + logVectorisedView(pkt.Data, length, buf) + if _, err := e.file.Write(buf.Bytes()); err != nil { + panic(err) } + } +} + +// WritePacket implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.dumpPacket(gso, protocol, pkt) + return e.lower.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + view := pkts[0].Data.ToView() + for _, pkt := range pkts { + e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), + }) + } + return e.lower.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { + logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + length := vv.Size() + if length > int(e.maxPCAPLen) { + length = int(e.maxPCAPLen) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + panic(err) + } + logVectorisedView(vv, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } } - return e.lower.WritePacket(r, gso, hdr, payload, protocol) + return e.lower.WriteRawPacket(vv) +} + +func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { + if length <= 0 { + return + } + for _, v := range vv.Views() { + if len(v) > length { + v = v[:length] + } + n, err := buf.Write(v) + if err != nil { + panic(err) + } + length -= n + if length == 0 { + return + } + } } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 @@ -284,6 +340,13 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie return } + // We aren't guaranteed to have a transport header - it's possible for + // writes via raw endpoints to contain only network headers. + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) + return + } + // Figure out the transport layer info. transName := "unknown" srcPort := uint16(0) diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 92dce8fac..a71a493fc 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -6,7 +6,5 @@ go_library( name = "tun", srcs = ["tun_unsafe.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 2597d4b3e..134837943 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -8,9 +9,7 @@ go_library( "waitable.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/gate", "//pkg/tcpip", diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 3b6ac2ff7..a8de38979 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -40,23 +40,22 @@ type Endpoint struct { // New creates a new waitable link-layer endpoint. It wraps around another // endpoint and allows the caller to block new write/dispatch calls and wait for // the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, } - return stack.RegisterLinkEndpoint(e), e } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -100,12 +99,36 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, protocol, pkt) + e.writeGate.Leave() + return err +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if !e.writeGate.Enter() { + return len(pkts), nil + } + + n, err := e.lower.WritePackets(r, gso, pkts, protocol) + e.writeGate.Leave() + return n, err +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WriteRawPacket(vv) e.writeGate.Leave() return err } @@ -121,3 +144,6 @@ func (e *Endpoint) WaitWrite() { func (e *Endpoint) WaitDispatch() { e.dispatchGate.Close() } + +// Wait implements stack.LinkEndpoint.Wait. +func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 56e18ecb0..31b11a27a 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.dispatchCount++ } @@ -65,31 +65,45 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += len(pkts) + return len(pkts), nil +} + +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + e.writeCount++ + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*countedEndpoint) Wait() {} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -97,7 +111,7 @@ func TestWaitWrite(t *testing.T) { func TestWaitDispatch(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Check that attach happens. wep.Attach(ep) @@ -106,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } @@ -139,7 +153,7 @@ func TestOtherMethods(t *testing.T) { hdrLen: hdrLen, linkAddr: linkAddr, } - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) if v := wep.MTU(); v != mtu { t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index f36f49453..9d16ff8c9 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d95d44f56..e7617229b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,9 +7,7 @@ go_library( name = "arp", srcs = ["arp.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index fd6395fc1..da8482509 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -16,9 +16,9 @@ // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. // -// To use it in the networking stack, pass arp.ProtocolName as one of the -// network protocols when calling stack.New. Then add an "arp" address to -// every NIC on the stack that should respond to ARP requests. That is: +// To use it in the networking stack, pass arp.NewProtocol() as one of the +// network protocols when calling stack.New. Then add an "arp" address to every +// NIC on the stack that should respond to ARP requests. That is: // // if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { // // handle err @@ -33,9 +33,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ARP protocol name. - ProtocolName = "arp" - // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber @@ -45,7 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -61,7 +58,7 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -82,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return @@ -100,19 +102,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) + fallthrough // also fill the cache from requests case header.ARPReply: + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -129,12 +137,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ - nicid: nicid, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, }, nil @@ -159,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -200,8 +210,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an ARP network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..8e6048a21 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -44,14 +44,19 @@ type testContext struct { } func newTestContext(t *testing.T) *testContext { - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + }) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } @@ -97,19 +102,21 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi := <-c.linkEP.C + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.Header.View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 118bfc763..acf1e022c 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "reassembler_list", @@ -24,9 +25,10 @@ go_library( "reassembler_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/tcpip", "//pkg/tcpip/buffer", ], ) @@ -42,11 +44,3 @@ go_test( embed = [":fragmentation"], deps = ["//pkg/tcpip/buffer"], ) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1628a82be..6da5238ec 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,6 +17,7 @@ package fragmentation import ( + "fmt" "log" "sync" "time" @@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t // Process processes an incoming fragment belonging to an ID // and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) { +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { @@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } f.mu.Unlock() - res, done, consumed := r.process(first, last, more, vv) - + res, done, consumed, err := r.process(first, last, more, vv) + if err != nil { + // We probably got an invalid sequence of fragments. Just + // discard the reassembler and move on. + f.mu.Lock() + f.release(r) + f.mu.Unlock() + return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + } f.mu.Lock() f.size += consumed if done { @@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } } f.mu.Unlock() - return res, done + return res, done, nil } func (f *Fragmentation) release(r *reassembler) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 799798544..72c0f53be 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) { t.Run(c.comment, func(t *testing.T) { f := NewFragmentation(1024, 512, DefaultReassembleTimeout) for i, in := range c.in { - vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + if err != nil { + t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + } if !reflect.DeepEqual(vv, c.out[i].vv) { t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) } @@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) { time.Sleep(2 * timeout) // Send another fragment that completes a packet. // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done := f.Process(0, 1, 1, false, vv(1, "1")) + _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) + if err != nil { + t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + } if done { t.Errorf("Fragmentation does not respect the reassembling timeout.") } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 8037f734b..9e002e396 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) { +func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err)) + return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) } - return res, true, consumed + return res, true, consumed, nil } func (r *reassembler) tooOld(timeout time.Duration) bool { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4b3bd74fa..4144a7837 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -144,32 +144,47 @@ func (*testObject) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*testObject) Wait() {} + // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + return tcpip.ErrNotSupported +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -182,7 +197,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -221,7 +239,10 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -261,7 +282,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -349,7 +372,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -412,13 +437,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -451,7 +480,10 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -491,7 +523,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -501,6 +535,7 @@ func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" cases := []struct { name string expectedCount int @@ -552,7 +587,7 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + SrcAddr: outerSrcAddr, DstAddr: localIpv6Addr, }) @@ -599,8 +634,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index be84fa63d..aeddfcdd4 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -9,9 +10,7 @@ go_library( "ipv4.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index a25756443..32bf39e43 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -95,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } @@ -104,19 +112,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7a06f525..e645cf62c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -14,9 +14,9 @@ // Package ipv4 contains the implementation of the ipv4 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv4.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv4.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv4.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv4 @@ -32,9 +32,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv4 protocol name. - ProtocolName = "ipv4" - // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -42,28 +39,33 @@ const ( // TotalLength field of the ipv4 header. MaxTotalSize = 0xffff + // DefaultTTL is the default time-to-live value for this endpoint. + DefaultTTL = 64 + // buckets is the number of identifier buckets. buckets = 2048 ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation + protocol *protocol } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, } return e, nil @@ -71,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi // DefaultTTL is the default time-to-live value for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -87,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. @@ -115,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -135,122 +138,167 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payload.Size()) + length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), - TTL: ttl, - Protocol: uint8(protocol), + TTL: params.TTL, + TOS: params.TOS, + Protocol: uint8(params.Protocol), SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() return nil } +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("multiple packets in local loop") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + ip := header.IPv4(pkt.Data.First()) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -264,10 +312,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) } ip.SetID(uint16(id)) } @@ -277,37 +325,63 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, pkt.Clone()) } if loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { + if pkt.Data.Size() == 0 { + // Drop the packet as it's marked as a fragment but has + // no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < h.FragmentOffset() { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } var ready bool - vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + var err error + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } if !ready { return } @@ -315,24 +389,24 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() {} -type protocol struct{} +type protocol struct { + ids []uint32 + hashIV uint32 -// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv4 protocol number. @@ -358,12 +432,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -378,7 +474,7 @@ func calculateMTU(mtu uint32) uint32 { // hashRoute calculates a hash value for the given route. It uses the source & // destination address, the transport protocol number, and a random initial // value (generated once on initialization) to generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { t := r.LocalAddress a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = r.RemoteAddress @@ -386,22 +482,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -var ( - ids []uint32 - hashIV uint32 -) - -func init() { - ids = make([]uint32, buckets) +// NewProtocol returns an IPv4 network protocol. +func NewProtocol() stack.NetworkProtocol { + ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. r := hash.RandN32(1 + buckets) for i := range ids { ids[i] = r[i] } - hashIV = r[buckets] + hashIV := r[buckets] - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) + return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..e900f1b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -33,24 +33,20 @@ import ( ) func TestExcludeBroadcast(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: 1, @@ -117,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -136,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -177,28 +173,19 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan tcpip.PacketBuffer packetCollectorErrors []*tcpip.Error } // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, - Ch: make(chan packetInfo, size), +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), + Ch: make(chan tcpip.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec -} - -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView } // Drain removes all outbound packets from the channel and counts them. @@ -215,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -241,10 +223,11 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. - s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +249,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } @@ -298,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := tcpip.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []tcpip.PacketBuffer L: for { select { @@ -351,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -368,3 +357,119 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestInvalidFragments(t *testing.T) { + // These packets have both IHL and TotalLength set to 0. + testCases := []struct { + name string + packets [][]byte + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + "ihl_totallen_zero_valid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + "ihl_totallen_zero_invalid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + // Total Length of 37(20 bytes IP header + 17 bytes of + // payload) + // Frag Offset of 0x1ffe = 8190*8 = 65520 + // Leading to the fragment end to be past 65535. + "ihl_totallen_valid_invalid_frag_offset_1", + [][]byte{ + {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + // The following 3 tests were found by running a fuzzer and were + // triggering a panic in the IPv4 reassembler code. + { + "ihl_less_than_ipv4_minimum_size_1", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_2", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_3", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "fragment_with_short_total_len_extra_payload", + [][]byte{ + {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + { + "multiple_fragments_with_more_fragments_set_to_false", + [][]byte{ + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + 1, + 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + const nicID tcpip.NICID = 42 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + }, + }) + + var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) + var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) + ep := channel.New(10, 1500, linkAddr) + s.CreateNIC(nicID, sniffer.New(ep)) + + for _, pkt := range tc.packets { + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index c71b69123..e4e273460 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -9,9 +10,7 @@ go_library( "ipv6.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -25,6 +24,7 @@ go_test( size = "small", srcs = [ "icmp_test.go", + "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -36,6 +36,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b4d0295bf..1c3410618 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -21,21 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -const ( - // ndpHopLimit is the expected IP hop limit value of 255 for received - // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, - // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently - // drop the NDP packet. All outgoing NDP packets must use this value for - // its IP hop limit field. - ndpHopLimit = 255 -) - // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -61,35 +52,54 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } h := header.ICMPv6(v) + iph := header.IPv6(netHeader) + + // Validate ICMPv6 checksum before processing the packet. + // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. + // This copy is used as extra payload during the checksum calculation. + payload := pkt.Data + payload.RemoveFirst() + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. switch h.Type() { case header.ICMPv6NeighborSolicit, header.ICMPv6NeighborAdvert, header.ICMPv6RouterSolicit, header.ICMPv6RouterAdvert, header.ICMPv6RedirectMsg: - if header.IPv6(netHeader).HopLimit() != ndpHopLimit { + if iph.HopLimit() != header.NDPHopLimit { + received.Invalid.Increment() + return + } + + if h.Code() != 0 { received.Invalid.Increment() return } @@ -103,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -113,33 +123,80 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + + ns := header.NDPNeighborSolicit(h.NDPPayload()) + targetAddr := ns.TargetAddress() + s := r.Stack() + rxNICID := r.NICID() + + isTentative, err := s.IsAddrTentative(rxNICID, targetAddr) + if err != nil { + // We will only get an error if rxNICID is unrecognized, + // which should not happen. For now short-circuit this + // packet. + // + // TODO(b/141002840): Handle this better? + return + } + + if isTentative { + // If the target address is tentative and the source + // of the packet is a unicast (specified) address, then + // the source of the packet is attempting to perform + // address resolution on the target. In this case, the + // solicitation is silently ignored, as per RFC 4862 + // section 5.4.3. + // + // If the target address is tentative and the source of + // the packet is the unspecified address (::), then we + // know another node is also performing DAD for the + // same address (since targetAddr is tentative for us, + // we know we are also performing DAD on it). In this + // case we let the stack know so it can handle such a + // scenario and do nothing further with the NDP NS. + if iph.SourceAddress() == header.IPv6Any { + s.DupTentativeAddrDetected(rxNICID, targetAddr) + } + + // Do not handle neighbor solicitations targeted + // to an address that is tentative on the received + // NIC any further. + return + } + + // At this point we know that targetAddr is not tentative on + // rxNICID so the packet is processed as defined in RFC 4861, + // as per RFC 4862 section 5.4.3. + + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag - copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr) - pkt[icmpV6OptOffset] = ndpOptDstLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:]) + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the @@ -152,9 +209,26 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + // TODO(tamird/ghanan): there exists an explicit NDP option that is + // used to update the neighbor table with link addresses for a + // neighbor from an NS (see the Source Link Layer option RFC + // 4861 section 4.6.1 and section 7.2.3). + // + // Furthermore, the entirety of NDP handling here seems to be + // contradicted by RFC 4861. + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) + + // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) + // + // 7.1.2. Validation of Neighbor Advertisements + // + // The IP Hop Limit field has a value of 255, i.e., the packet + // could not possibly have been forwarded by a router. + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return } @@ -166,10 +240,45 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + + na := header.NDPNeighborAdvert(h.NDPPayload()) + targetAddr := na.TargetAddress() + stack := r.Stack() + rxNICID := r.NICID() + + isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr) + if err != nil { + // We will only get an error if rxNICID is unrecognized, + // which should not happen. For now short-circuit this + // packet. + // + // TODO(b/141002840): Handle this better? + return + } + + if isTentative { + // We just got an NA from a node that owns an address we + // are performing DAD on, implying the address is not + // unique. In this case we let the stack know so it can + // handle such a scenario and do nothing furthur with + // the NDP NA. + stack.DupTentativeAddrDetected(rxNICID, targetAddr) + return + } + + // At this point we know that the targetAddress is not tentative + // on rxNICID. However, targetAddr may still be assigned to + // rxNICID but not tentative (it could be permanent). Such a + // scenario is beyond the scope of RFC 4862. As such, we simply + // ignore such a scenario for now and proceed as normal. + // + // TODO(b/143147598): Handle the scenario described above. Also + // inform the netstack integration that a duplicate address was + // detected outside of DAD. + + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: @@ -178,14 +287,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { sent.Dropped.Increment() return } @@ -197,7 +308,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -209,8 +320,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.RouterSolicit.Increment() case header.ICMPv6RouterAdvert: + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + p := h.NDPPayload() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if len(p) < header.NDPRAMinimumSize { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + ra := header.NDPRouterAdvert(p) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + received.RouterAdvert.Increment() + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() @@ -262,13 +416,15 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: defaultIPv6HopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 227a65cf2..335f634d5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,7 +15,6 @@ package ipv6 import ( - "fmt" "reflect" "strings" "testing" @@ -31,7 +30,7 @@ import ( ) const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) @@ -56,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { return nil } @@ -66,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -81,10 +80,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li } func TestICMPCounts(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -130,7 +131,7 @@ func TestICMPCounts(t *testing.T) { {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, @@ -142,11 +143,13 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: r.DefaultTTL(), + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -177,13 +180,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { t := v.Type() for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter)) - case reflect.Struct: + if s, ok := v.Interface().(*tcpip.StatCounter); ok { + f(t.Field(i).Name, s) + } else { visitStats(v, f) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -206,41 +206,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), + s0: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + s1: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { @@ -279,20 +276,22 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi := <-args.src.C { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -364,3 +363,537 @@ func TestLinkResolution(t *testing.T) { routeICMPv6Packet(t, args, nil) } } + +func TestICMPChecksumValidationSimple(t *testing.T) { + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + "RouterSolicit", + header.ICMPv6RouterSolicit, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + "RouterAdvert", + header.ICMPv6RouterAdvert, + header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + "NeighborSolicit", + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborSolicitMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + "NeighborAdvert", + header.ICMPv6NeighborAdvert, + header.ICMPv6NeighborAdvertSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + "RedirectMsg", + header.ICMPv6RedirectMsg, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331a8bdaa..dd31f0fb7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -14,13 +14,15 @@ // Package ipv6 contains the implementation of the ipv6 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv6.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv6.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv6.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv6 import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -28,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv6 protocol name. - ProtocolName = "ipv6" - // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -38,23 +37,24 @@ const ( // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff - // defaultIPv6HopLimit is the default hop limit for IPv6 Packets - // egressed by Netstack. - defaultIPv6HopLimit = 255 + // DefaultTTL is the default hop limit for IPv6 Packets egressed by + // Netstack. + DefaultTTL = 64 ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher + protocol *protocol } // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv6 endpoint ID. @@ -97,25 +97,37 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { - length := uint16(hdr.UsedLength() + payload.Size()) +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { + length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, - NextHeader: uint8(protocol), - HopLimit: ttl, + NextHeader: uint8(params.Protocol), + HopLimit: params.TTL, + TrafficClass: params.TOS, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -123,49 +135,68 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("not implemented") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + hdr := &pkts[i].Header + size := pkts[i].DataSize + ip := e.addIPHeader(r, hdr, size, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // TODO(b/119580726): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} -type protocol struct{} - -// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} +type protocol struct { + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv6 protocol number. @@ -190,25 +221,48 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, + protocol: p, }, nil } // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -221,8 +275,7 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an IPv6 network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..1cbfa7278 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,270 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveUDP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8e4cf0e74..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -31,16 +32,18 @@ import ( func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) } + { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -95,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -107,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) { {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }}, {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { @@ -148,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) { // Receive the NDP packet with an invalid hop limit // value. - handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) // Invalid count should have increased. if got := invalid.Value(); got != 1 { @@ -162,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) { } // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, ndpHopLimit, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) // Rx count of NDP packet of type typ.typ should have // increased. @@ -177,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) { }) } } + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + } + }) + } +} diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go new file mode 100644 index 000000000..ab24372e7 --- /dev/null +++ b/pkg/tcpip/packet_buffer.go @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // DataOffset is used for GSO output. It is the offset into the Data + // field where the payload of this packet starts. + DataOffset int + + // DataSize is used for GSO output. It is the size of this packet's + // payload. + DataSize int + + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. +func (pk PacketBuffer) Clone() PacketBuffer { + pk.Data = pk.Data.Clone(nil) + return pk +} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go new file mode 100644 index 000000000..ad3cc24fa --- /dev/null +++ b/pkg/tcpip/packet_buffer_state.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 989058413..e156b01f6 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +7,7 @@ go_library( name = "ports", srcs = ["ports.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 315780c0c..6c5e19e8f 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -19,6 +19,7 @@ import ( "math" "math/rand" "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,6 +28,10 @@ const ( // FirstEphemeral is the first ephemeral port. FirstEphemeral = 16000 + // numEphemeralPorts it the mnumber of available ephemeral ports to + // Netstack. + numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 + anyIPAddress tcpip.Address = "" ) @@ -36,54 +41,182 @@ type portDescriptor struct { port uint16 } +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool +} + +func (f Flags) bits() reuseFlag { + var rf reuseFlag + if f.MostRecent { + rf |= mostRecentFlag + } + if f.LoadBalanced { + rf |= loadBalancedFlag + } + return rf +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex allocatedPorts map[portDescriptor]bindAddresses + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 } +type reuseFlag int + +const ( + mostRecentFlag reuseFlag = 1 << iota + loadBalancedFlag + nextFlag + + flagMask = nextFlag - 1 +) + type portNode struct { - reuse bool - refs int + // refs stores the count for each possible flag combination. + refs [nextFlag]int } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]portNode +func (p portNode) totalRefs() int { + var total int + for _, r := range p.refs { + total += r + } + return total +} -// isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { - if addr == anyIPAddress { - if len(b) == 0 { - return true +// flagRefs returns the number of references with all specified flags. +func (p portNode) flagRefs(flags reuseFlag) int { + var total int + for i, r := range p.refs { + if reuseFlag(i)&flags == flags { + total += r } - if !reuse { + } + return total +} + +// allRefsHave returns if all references have all specified flags. +func (p portNode) allRefsHave(flags reuseFlag) bool { + for i, r := range p.refs { + if reuseFlag(i)&flags == flags && r > 0 { return false } - for _, n := range b { - if !n.reuse { + } + return true +} + +// intersectionRefs returns the set of flags shared by all references. +func (p portNode) intersectionRefs() reuseFlag { + intersection := flagMask + for i, r := range p.refs { + if r > 0 { + intersection &= reuseFlag(i) + } + } + return intersection +} + +// deviceNode is never empty. When it has no elements, it is removed from the +// map that references it. +type deviceNode map[tcpip.NICID]portNode + +// isAvailable checks whether binding is possible by device. If not binding to a +// device, check against all portNodes. If binding to a specific device, check +// against the unspecified device and the provided device. +// +// If either of the port reuse flags is enabled on any of the nodes, all nodes +// sharing a port must share at least one reuse flag. This matches Linux's +// behavior. +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { + flagBits := flags.bits() + if bindToDevice == 0 { + // Trying to binding all devices. + if flagBits == 0 { + // Can't bind because the (addr,port) is already bound. + return false + } + intersection := flagMask + for _, p := range d { + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { + // Can't bind because the (addr,port) was + // previously bound without reuse. return false } } return true } - // If all addresses for this portDescriptor are already bound, no - // address is available. - if n, ok := b[anyIPAddress]; ok { - if !reuse { + intersection := flagMask + + if p, ok := d[0]; ok { + intersection = p.intersectionRefs() + if intersection&flagBits == 0 { return false } - if !n.reuse { + } + + if p, ok := d[bindToDevice]; ok { + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { + return false + } + } + + return true +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]deviceNode + +// isAvailable checks whether an IP address is available to bind to. If the +// address is the "any" address, check all other addresses. Otherwise, just +// check against the "any" address and the provided address. +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { + if addr == anyIPAddress { + // If binding to the "any" address then check that there are no conflicts + // with all addresses. + for _, d := range b { + if !d.isAvailable(flags, bindToDevice) { + return false + } + } + return true + } + + // Check that there is no conflict with the "any" address. + if d, ok := b[anyIPAddress]; ok { + if !d.isAvailable(flags, bindToDevice) { return false } } - if n, ok := b[addr]; ok { - if !reuse { + // Check that this is no conflict with the provided address. + if d, ok := b[addr]; ok { + if !d.isAvailable(flags, bindToDevice) { return false } - return n.reuse } + return true } @@ -97,11 +230,40 @@ func NewPortManager() *PortManager { // is suitable for its needs, and stopping when a port is found or an error // occurs. func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { - count := uint16(math.MaxUint16 - FirstEphemeral + 1) - offset := uint16(rand.Int31n(int32(count))) + offset := uint32(rand.Int31n(numEphemeralPorts)) + return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) +} + +// portHint atomically reads and returns the s.hint value. +func (s *PortManager) portHint() uint32 { + return atomic.LoadUint32(&s.hint) +} + +// incPortHint atomically increments s.hint by 1. +func (s *PortManager) incPortHint() { + atomic.AddUint32(&s.hint, 1) +} + +// PickEphemeralPortStable starts at the specified offset + s.portHint and +// iterates over all ephemeral ports, allowing the caller to decide whether a +// given port is suitable for its needs and stopping when a port is found or an +// error occurs. +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) + if err == nil { + s.incPortHint() + } + return p, err - for i := uint16(0); i < count; i++ { - port = FirstEphemeral + (offset+i)%count +} + +// pickEphemeralPort starts at the offset specified from the FirstEphemeral port +// and iterates over the number of ports specified by count and allows the +// caller to decide whether a given port is suitable for its needs, and stopping +// when a port is found or an error occurs. +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + for i := uint32(0); i < count; i++ { + port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) if err != nil { return 0, err @@ -116,17 +278,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse) { + if !addrs.isAvailable(addr, flags, bindToDevice) { return false } } @@ -138,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -153,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } + flagBits := flags.bits() // Reserve port on all network protocols. for _, network := range networks { @@ -171,12 +334,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - if n, ok := m[addr]; ok { - n.refs++ - m[addr] = n - } else { - m[addr] = portNode{reuse: reuse, refs: 1} + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d } + n := d[bindToDevice] + n.refs[flagBits]++ + d[bindToDevice] = n } return true @@ -184,22 +349,30 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() + flagBits := flags.bits() + for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - n, ok := m[addr] + d, ok := m[addr] if !ok { continue } - n.refs-- - if n.refs == 0 { + n, ok := d[bindToDevice] + if !ok { + continue + } + n.refs[flagBits]-- + d[bindToDevice] = n + if n.refs == [nextFlag]int{} { + delete(d, bindToDevice) + } + if len(d) == 0 { delete(m, addr) - } else { - m[addr] = n } if len(m) == 0 { delete(s.allocatedPorts, desc) diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 689401661..d6969d050 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,6 +15,7 @@ package ports import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -32,8 +33,9 @@ type portReserveTestAction struct { port uint16 ip tcpip.Address want *tcpip.Error - reuse bool + flags Flags release bool + device tcpip.NICID } func TestPortReservation(t *testing.T) { @@ -48,7 +50,7 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, }, }, { @@ -59,7 +61,7 @@ func TestPortReservation(t *testing.T) { /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -69,36 +71,206 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 00, ip: fakeIPAddress, want: nil}, {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to ip with reuseport", actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to inaddr any with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, + }, + }, { + tname: "bind twice with device fails", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 3, want: nil}, + {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 1, want: nil}, + {port: 24, ip: fakeIPAddress, device: 2, want: nil}, + }, + }, { + tname: "bind to device and then without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "binding with reuseport and device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "mixing reuseport and not reuseport by binding to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: nil}, + }, + }, { + tname: "can't bind to 0 after mixing reuseport and not reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind and release", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, + + // Release the bind to device 0 and try again. + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, + }, + }, { + tname: "bind twice with reuseport once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "release an unreserved device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, + // The below don't exist. + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + // Release all. + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, + }, + }, { + tname: "bind with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind twice with reuseaddr once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, }, } { @@ -108,12 +280,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) @@ -125,7 +297,6 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - pm := NewPortManager() customErr := &tcpip.Error{} for _, test := range []struct { name string @@ -169,9 +340,63 @@ func TestPickEphemeralPort(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) } }) } } + +func TestPickEphemeralPortStable(t *testing.T) { + customErr := &tcpip.Error{} + for _, test := range []struct { + name string + f func(port uint16) (bool, *tcpip.Error) + wantErr *tcpip.Error + wantPort uint16 + }{ + { + name: "no-port-available", + f: func(port uint16) (bool, *tcpip.Error) { + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + { + name: "port-tester-error", + f: func(port uint16) (bool, *tcpip.Error) { + return false, customErr + }, + wantErr: customErr, + }, + { + name: "only-port-16042-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port == FirstEphemeral+42 { + return true, nil + } + return false, nil + }, + wantPort: FirstEphemeral + 42, + }, + { + name: "only-port-under-16000-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port < FirstEphemeral { + return true, nil + } + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + } { + t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { + t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + } + }) + } +} diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index a57752a7c..d7496fde6 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_connect", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index e2021cd15..2239c1e66 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -126,7 +126,10 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { @@ -138,11 +141,11 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) + linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil { + if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index dad8ef399..875561566 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_echo", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 1716be285..bca73cbb1 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -111,7 +111,10 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { @@ -128,7 +131,7 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{ + linkEP, err := fdbased.New(&fdbased.Options{ FDs: []int{fd}, MTU: mtu, EthernetHeader: *tap, @@ -137,7 +140,7 @@ func main() { if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, linkID); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 76b5f4ffa..b31ddba2f 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,12 +1,10 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "seqnum", srcs = ["seqnum.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 788de3dfe..69077669a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "linkaddrentry_list", @@ -21,6 +22,7 @@ go_library( "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", + "ndp.go", "nic.go", "registration.go", "route.go", @@ -29,11 +31,10 @@ go_library( "transport_demuxer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/ilist", + "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -51,18 +52,26 @@ go_test( name = "stack_x_test", size = "small", srcs = [ + "ndp_test.go", "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ ":stack", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) @@ -76,11 +85,3 @@ go_test( "//pkg/tcpip", ], ) - -filegroup( - name = "autogen", - srcs = [ - "linkaddrentry_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index f8156be47..3a20839da 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "golang.org/x/time/rate" ) @@ -33,54 +31,11 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - mu sync.RWMutex - l *rate.Limiter + *rate.Limiter } // NewICMPRateLimiter returns a global rate limiter for controlling the rate // at which ICMP messages are generated by the stack. func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)} -} - -// Allow returns true if we are allowed to send at least 1 message at the -// moment. -func (i *ICMPRateLimiter) Allow() bool { - i.mu.RLock() - allow := i.l.Allow() - i.mu.RUnlock() - return allow -} - -// Limit returns the maximum number of ICMP messages that can be sent in one -// second. -func (i *ICMPRateLimiter) Limit() rate.Limit { - i.mu.RLock() - defer i.mu.RUnlock() - return i.l.Limit() -} - -// SetLimit sets the maximum number of ICMP messages that can be sent in one -// second. -func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) { - i.mu.RLock() - defer i.mu.RUnlock() - i.l.SetLimit(newLimit) -} - -// Burst returns how many ICMP messages can be sent at any single instant. -func (i *ICMPRateLimiter) Burst() int { - i.mu.RLock() - defer i.mu.RUnlock() - return i.l.Burst() -} - -// SetBurst sets the maximum number of ICMP messages allowed at any single -// instant. -// -// NOTE: Changing Burst causes the underlying rate limiter to be recreated. -func (i *ICMPRateLimiter) SetBurst(burst int) { - i.mu.Lock() - i.l = rate.NewLimiter(i.l.Limit(), burst) - i.mu.Unlock() + return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go new file mode 100644 index 000000000..90664ba8a --- /dev/null +++ b/pkg/tcpip/stack/ndp.go @@ -0,0 +1,1157 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "log" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor + // Solicitation messages to send when doing Duplicate Address Detection + // for a tentative address. + // + // Default = 1 (from RFC 4862 section 5.1) + defaultDupAddrDetectTransmits = 1 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending NDP Neighbor solicitation messages. + // + // Default = 1s (from RFC 4861 section 10). + defaultRetransmitTimer = time.Second + + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + // + // Default = true. + defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements, as a host. + // + // Default = true. + defaultDiscoverDefaultRouters = true + + // defaultDiscoverOnLinkPrefixes is the default configuration for + // whether or not to discover on-link prefixes from incoming Router + // Advertisements' Prefix Information option, as a host. + // + // Default = true. + defaultDiscoverOnLinkPrefixes = true + + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending NDP Neighbor solicitation messages. Note, RFC 4861 does + // not impose a minimum Retransmit Timer, but we do here to make sure + // the messages are not sent all at once. We also come to this value + // because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. + // Note, the unit of the RetransmitTimer field in the Router + // Advertisement is milliseconds. + // + // Min = 1ms. + minimumRetransmitTimer = time.Millisecond + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + // + // Max = 10. + MaxDiscoveredDefaultRouters = 10 + + // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered + // on-link prefixes. The stack should stop discovering new on-link + // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + // + // Max = 10. + MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 +) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour +) + +// NDPDispatcher is the interface integrators of netstack must implement to +// receive and handle NDP related events. +type NDPDispatcher interface { + // OnDuplicateAddressDetectionStatus will be called when the DAD process + // for an address (addr) on a NIC (with ID nicID) completes. resolved + // will be set to true if DAD completed successfully (no duplicate addr + // detected); false otherwise (addr was detected to be a duplicate on + // the link the NIC is a part of, or it was stopped for some other + // reason, such as the address being removed). If an error occured + // during DAD, err will be set and resolved must be ignored. + // + // This function is permitted to block indefinitely without interfering + // with the stack's operation. + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered will be called when a new default router is + // discovered. Implementations must return true if the newly discovered + // router should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + + // OnDefaultRouterInvalidated will be called when a discovered default + // router that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + + // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + + // OnOnLinkPrefixInvalidated will be called when a discovered on-link + // prefix that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) +} + +// NDPConfigurations is the NDP configurations for the netstack. +type NDPConfigurations struct { + // The number of Neighbor Solicitation messages to send when doing + // Duplicate Address Detection for a tentative address. + // + // Note, a value of zero effectively disables DAD. + DupAddrDetectTransmits uint8 + + // The amount of time to wait between sending Neighbor solicitation + // messages. + // + // Must be greater than 0.5s. + RetransmitTimer time.Duration + + // HandleRAs determines whether or not Router Advertisements will be + // processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers will + // be discovered from Router Advertisements. This configuration is + // ignored if HandleRAs is false. + DiscoverDefaultRouters bool + + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes + // will be discovered from Router Advertisements' Prefix Information + // option. This configuration is ignored if HandleRAs is false. + DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool +} + +// DefaultNDPConfigurations returns an NDPConfigurations populated with +// default values. +func DefaultNDPConfigurations() NDPConfigurations { + return NDPConfigurations{ + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + } +} + +// validate modifies an NDPConfigurations with valid values. If invalid values +// are present in c, the corresponding default values will be used instead. +// +// If RetransmitTimer is less than minimumRetransmitTimer, then a value of +// defaultRetransmitTimer will be used. +func (c *NDPConfigurations) validate() { + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } +} + +// ndpState is the per-interface NDP state. +type ndpState struct { + // The NIC this ndpState is for. + nic *NIC + + // configs is the per-interface NDP configurations. + configs NDPConfigurations + + // The DAD state to send the next NS message, or resolve the address. + dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState + + // The on-link prefixes discovered through Router Advertisements' Prefix + // Information option. + onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The addresses generated by SLAAC. + autoGenAddresses map[tcpip.Address]autoGenAddressState +} + +// dadState holds the Duplicate Address Detection timer and channel to signal +// to the DAD goroutine that DAD should stop. +type dadState struct { + // The DAD timer to send the next NS message, or resolve the address. + timer *time.Timer + + // Used to let the DAD timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this dadState is associated with. + done *bool +} + +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement (RA). +type defaultRouterState struct { + invalidationTimer *time.Timer + + // Used to inform the timer not to invalidate the default router (R) in + // a race condition (T1 is a goroutine that handles an RA from R and T2 + // is the goroutine that handles R's invalidation timer firing): + // T1: Receive a new RA from R + // T1: Obtain the NIC's lock before processing the RA + // T2: R's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends R's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates R immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate R, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// onLinkPrefixState holds data associated with an on-link prefix discovered by +// a Router Advertisement's Prefix Information option (PI) when the NDP +// configurations was configured to do so. +type onLinkPrefixState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition (T1 is a goroutine that handles a PI for P and T2 + // is the goroutine that handles P's invalidation timer firing): + // T1: Receive a new PI for P + // T1: Obtain the NIC's lock before processing the PI + // T2: P's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends P's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates P immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate P, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// autoGenAddressState holds data associated with an address generated via +// SLAAC. +type autoGenAddressState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the SLAAC address (A) in + // a race condition (T1 is a goroutine that handles a PI for A and T2 + // is the goroutine that handles A's invalidation timer firing): + // T1: Receive a new PI for A + // T1: Obtain the NIC's lock before processing the PI + // T2: A's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends A's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates A immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate A, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool + + // Nonzero only when the address is not valid forever (invalidationTimer + // is not nil). + validUntil time.Time +} + +// startDuplicateAddressDetection performs Duplicate Address Detection. +// +// This function must only be called by IPv6 addresses that are currently +// tentative. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { + // addr must be a valid unicast IPv6 address. + if !header.IsV6UnicastAddress(addr) { + return tcpip.ErrAddressFamilyNotSupported + } + + // Should not attempt to perform DAD on an address that is currently in + // the DAD process. + if _, ok := ndp.dad[addr]; ok { + // Should never happen because we should only ever call this + // function for newly created addresses. If we attemped to + // "add" an address that already existed, we would returned an + // error since we attempted to add a duplicate address, or its + // reference count would have been increased without doing the + // work that would have been done for an address that was brand + // new. See NIC.addPermanentAddressLocked. + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + } + + remaining := ndp.configs.DupAddrDetectTransmits + + { + done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil { + return err + } + if done { + return nil + } + } + + remaining-- + + var done bool + var timer *time.Timer + timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { + var d bool + var err *tcpip.Error + + // doDadIteration does a single iteration of the DAD loop. + // + // Returns true if the integrator needs to be informed of DAD + // completing. + doDadIteration := func() bool { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if done { + // If we reach this point, it means that the DAD + // timer fired after another goroutine already + // obtained the NIC lock and stopped DAD before + // this function obtained the NIC lock. Simply + // return here and do nothing further. + return false + } + + ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] + if !ok { + // This should never happen. + // We should have an endpoint for addr since we + // are still performing DAD on it. If the + // endpoint does not exist, but we are doing DAD + // on it, then we started DAD at some point, but + // forgot to stop it when the endpoint was + // deleted. + panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) + } + + d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil || d { + delete(ndp.dad, addr) + + if err != nil { + log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) + } + + // Let the integrator know DAD has completed. + return true + } + + remaining-- + timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + return false + } + + if doDadIteration() && ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + } + }) + + ndp.dad[addr] = dadState{ + timer: timer, + done: &done, + } + + return nil +} + +// doDuplicateAddressDetection is called on every iteration of the timer, and +// when DAD starts. +// +// It handles resolving the address (if there are no more NS to send), or +// sending the next NS if there are more NS to send. +// +// This function must only be called by IPv6 addresses that are currently +// tentative. +// +// The NIC that ndp belongs to (n) MUST be locked. +// +// Returns true if DAD has resolved; false if DAD is still ongoing. +func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { + if ref.getKind() != permanentTentative { + // The endpoint should still be marked as tentative + // since we are still performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) + } + + if remaining == 0 { + // DAD has resolved. + ref.setKind(permanent) + return true, nil + } + + // Send a new NS. + snmc := header.SolicitedNodeAddr(addr) + snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] + if !ok { + // This should never happen as if we have the + // address, we should have the solicited-node + // address. + panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) + } + + // Use the unspecified address as the source address when performing + // DAD. + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(addr) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { + sent.Dropped.Increment() + return false, err + } + sent.NeighborSolicit.Increment() + + return false, nil +} + +// stopDuplicateAddressDetection ends a running Duplicate Address Detection +// process. Note, this may leave the DAD process for a tentative address in +// such a state forever, unless some other external event resolves the DAD +// process (receiving an NA from the true owner of addr, or an NS for addr +// (implying another node is attempting to use addr)). It is up to the caller +// of this function to handle such a scenario. Normally, addr will be removed +// from n right after this function returns or the address successfully +// resolved. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { + dad, ok := ndp.dad[addr] + if !ok { + // Not currently performing DAD on addr, just return. + return + } + + if dad.timer != nil { + dad.timer.Stop() + dad.timer = nil + + *dad.done = true + dad.done = nil + } + + delete(ndp.dad, addr) + + // Let the integrator know DAD did not resolve. + if ndp.nic.stack.ndpDisp != nil { + go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + } +} + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the NIC configured to handle RAs at all? + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a stack-wide configuration, so we check + // stack's forwarding flag to determine if the NIC is a routing + // interface. + if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + return + } + + // Is the NIC configured to discover default routers? + if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation timer. + timer := rtr.invalidationTimer + + // We should ALWAYS have an invalidation timer for a + // discovered router. + if timer == nil { + panic("ndphandlera: RA invalidation timer should not be nil") + } + + if !timer.Stop() { + // If we reach this point, then we know the + // timer fired after we already took the NIC + // lock. Inform the timer not to invalidate the + // router when it obtains the lock as we just + // got a new RA that refreshes its lifetime to a + // non-zero value. See + // defaultRouterState.doNotInvalidate for more + // details. + *rtr.doNotInvalidate = true + } + + timer.Reset(rl) + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + + // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter + // Discovery. + + // We know the options is valid as far as wire format is concerned since + // we got the Router Advertisement, as documented by this fn. Given this + // we do not check the iterator for errors on calls to Next. + it, _ := ra.Options().Iter(false) + for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) + + case header.NDPPrefixInformation: + prefix := opt.Subnet() + + // Is the prefix a link-local? + if header.IsV6LinkLocalAddress(prefix.ID()) { + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). + continue + } + + // Is the Prefix Length 0? + if prefix.Prefix() == 0 { + // ...Yes, skip as this is an invalid prefix + // as all IPv6 addresses cannot be on-link. + continue + } + + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) + } + + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) + } + } + + // TODO(b/141556115): Do (MTU) Parameter Discovery. + } +} + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationTimer.Stop() + rtr.invalidationTimer = nil + *rtr.doNotInvalidate = true + rtr.doNotInvalidate = nil + + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered a default router. + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition. See defaultRouterState.doNotInvalidate for more + // details. + var doNotInvalidate bool + + ndp.defaultRouters[ip] = defaultRouterState{ + invalidationTimer: time.AfterFunc(rl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if doNotInvalidate { + doNotInvalidate = false + return + } + + ndp.invalidateDefaultRouter(ip) + }), + doNotInvalidate: &doNotInvalidate, + } +} + +// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 +// address with prefix prefix with lifetime l. +// +// The prefix identified by prefix MUST NOT already be known. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered an on-link prefix. + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + // Informed by the integrator to not remember the prefix, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition. See onLinkPrefixState.doNotInvalidate for more + // details. + var doNotInvalidate bool + var timer *time.Timer + + // Only create a timer if the lifetime is not infinite. + if l < header.NDPInfiniteLifetime { + timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + } + + ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ + invalidationTimer: timer, + doNotInvalidate: &doNotInvalidate, + } +} + +// invalidateOnLinkPrefix invalidates a discovered on-link prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { + s, ok := ndp.onLinkPrefixes[prefix] + + // Is the on-link prefix still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + if s.invalidationTimer != nil { + s.invalidationTimer.Stop() + s.invalidationTimer = nil + *s.doNotInvalidate = true + } + + s.doNotInvalidate = nil + + delete(ndp.onLinkPrefixes, prefix) + + // Let the integrator know a discovered on-link prefix is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + } +} + +// prefixInvalidationCallback returns a new on-link prefix invalidation timer +// for prefix that fires after vl. +// +// doNotInvalidate is used to signal the timer when it fires at the same time +// that a prefix's valid lifetime gets refreshed. See +// onLinkPrefixState.doNotInvalidate for more details. +func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateOnLinkPrefix(prefix) + }) +} + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // Update the invalidation timer. + timer := prefixState.invalidationTimer + + if timer == nil && vl >= header.NDPInfiniteLifetime { + // Had infinite valid lifetime before and + // continues to have an invalid lifetime. Do + // nothing further. + return + } + + if timer != nil && !timer.Stop() { + // If we reach this point, then we know the timer alread fired + // after we took the NIC lock. Inform the timer to not + // invalidate the prefix once it obtains the lock as we just + // got a new PI that refreshes its lifetime to a non-zero value. + // See onLinkPrefixState.doNotInvalidate for more details. + *prefixState.doNotInvalidate = true + } + + if vl >= header.NDPInfiniteLifetime { + // Prefix is now valid forever so we don't need + // an invalidation timer. + prefixState.invalidationTimer = nil + ndp.onLinkPrefixes[prefix] = prefixState + return + } + + if timer != nil { + // We already have a timer so just reset it to + // expire after the new valid lifetime. + timer.Reset(vl) + return + } + + // We do not have a timer so just create a new one. + prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already have an auto-generated address for prefix. + for _, ref := range ndp.nic.endpoints { + if ref.protocol != header.IPv6ProtocolNumber { + continue + } + + if ref.configType != slaac { + continue + } + + addr := ref.ep.ID().LocalAddress + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + if refAddrWithPrefix.Subnet() != prefix { + continue + } + + // + // At this point, we know we are refreshing a SLAAC generated + // IPv6 address with the prefix, prefix. Do the work as outlined + // by RFC 4862 section 5.5.3.e. + // + + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) + } + + // TODO(b/143713887): Handle deprecating auto-generated address + // after the preferred lifetime. + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the + // address generated by SLAAC is as follows: + // + // 1) If the received Valid Lifetime is greater than 2 hours or + // greater than RemainingLifetime, set the valid lifetime of + // the address to the advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, + // ignore the advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 + // hours. + + // Handle the infinite valid lifetime separately as we do not + // keep a timer in this case. + if vl >= header.NDPInfiniteLifetime { + if addrState.invalidationTimer != nil { + // Valid lifetime was finite before, but now it + // is valid forever. + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer = nil + addrState.validUntil = time.Time{} + ndp.autoGenAddresses[addr] = addrState + } + + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, + // assume the remaining time to be the maximum possible value. + if addrState.invalidationTimer == nil { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + ndp.autoGenAddresses[addr] = addrState + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if addrState.invalidationTimer == nil { + addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) + } else { + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer.Reset(effectiveVl) + } + + addrState.validUntil = time.Now().Add(effectiveVl) + ndp.autoGenAddresses[addr] = addrState + return + } + + // We do not already have an address within the prefix, prefix. Do the + // work as outlined by RFC 4862 section 5.5.3.d if n is configured + // to auto-generated global addresses by SLAAC. + + // Are we configured to auto-generate new global addresses? + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } + + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. + addrBytes := make([]byte, header.IPv6AddressSize) + copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.Address(addrBytes) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + } + + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr) { + return + } + + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + // Informed by the integrator not to add the address. + return + } + + if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + }, FirstPrimaryEndpoint, permanent, slaac); err != nil { + panic(err) + } + + // Setup the timers to deprecate and invalidate this newly generated + // address. + + // TODO(b/143713887): Handle deprecating auto-generated addresses + // after the preferred lifetime. + + var doNotInvalidate bool + var vTimer *time.Timer + if vl < header.NDPInfiniteLifetime { + vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + } + + ndp.autoGenAddresses[addr] = autoGenAddressState{ + invalidationTimer: vTimer, + doNotInvalidate: &doNotInvalidate, + validUntil: time.Now().Add(vl), + } +} + +// invalidateAutoGenAddress invalidates an auto-generated address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { + if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { + return + } + + ndp.nic.removePermanentAddressLocked(addr) +} + +// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated +// address's resources from ndp. If the stack has an NDP dispatcher, it will +// be notified that addr has been invalidated. +// +// Returns true if ndp had resources for addr to cleanup. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { + state, ok := ndp.autoGenAddresses[addr] + + if !ok { + return false + } + + if state.invalidationTimer != nil { + state.invalidationTimer.Stop() + state.invalidationTimer = nil + *state.doNotInvalidate = true + } + + state.doNotInvalidate = nil + + delete(ndp.autoGenAddresses, addr) + + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } + + return true +} + +// autoGenAddrInvalidationTimer returns a new invalidation timer for an +// auto-generated address that fires after vl. +// +// doNotInvalidate is used to inform the timer when it fires at the same time +// that an auto-generated address's valid lifetime gets refreshed. See +// autoGenAddrState.doNotInvalidate for more details. +func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateAutoGenAddress(addr) + }) +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go new file mode 100644 index 000000000..666f86c33 --- /dev/null +++ b/pkg/tcpip/stack/ndp_test.go @@ -0,0 +1,2057 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "encoding/binary" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + linkAddr1 = "\x02\x02\x03\x04\x05\x06" + linkAddr2 = "\x02\x02\x03\x04\x05\x07" + linkAddr3 = "\x02\x02\x03\x04\x05\x08" + defaultTimeout = 100 * time.Millisecond +) + +var ( + llAddr1 = header.LinkLocalAddr(linkAddr1) + llAddr2 = header.LinkLocalAddr(linkAddr2) + llAddr3 = header.LinkLocalAddr(linkAddr3) +) + +// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent +// tcpip.Subnet, and an address where the lower half of the address is composed +// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. +func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { + prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixBytes), + PrefixLen: 64, + } + + subnet := prefix.Subnet() + + var addr tcpip.AddressWithPrefix + if header.IsValidUnicastEthernetAddress(linkAddr) { + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + } + + return prefix, subnet, addr +} + +// TestDADDisabled tests that an address successfully resolves immediately +// when DAD is not enabled (the default for an empty stack.Options). +func TestDADDisabled(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + // Should get the address immediately since we should not have performed + // DAD on it. + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + } + + // We should not have sent any NDP NS messages. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + t.Fatalf("got NeighborSolicit = %d, want = 0", got) + } +} + +// ndpDADEvent is a set of parameters that was passed to +// ndpDispatcher.OnDuplicateAddressDetectionStatus. +type ndpDADEvent struct { + nicID tcpip.NICID + addr tcpip.Address + resolved bool + err *tcpip.Error +} + +type ndpRouterEvent struct { + nicID tcpip.NICID + addr tcpip.Address + // true if router was discovered, false if invalidated. + discovered bool +} + +type ndpPrefixEvent struct { + nicID tcpip.NICID + prefix tcpip.Subnet + // true if prefix was discovered, false if invalidated. + discovered bool +} + +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + +var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) + +// ndpDispatcher implements NDPDispatcher so tests can know when various NDP +// related events happen for test purposes. +type ndpDispatcher struct { + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent +} + +// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { + if n.dadC != nil { + n.dadC <- ndpDADEvent{ + nicID, + addr, + resolved, + err, + } + } +} + +// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + true, + } + } + + return n.rememberRouter +} + +// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + false, + } + } +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + true, + } + } + + return n.rememberPrefix +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + false, + } + } +} + +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } + } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, + } + } +} + +// TestDADResolve tests that an address successfully resolves after performing +// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. +// Included in the subtests is a test to make sure that an invalid +// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +func TestDADResolve(t *testing.T) { + tests := []struct { + name string + dupAddrDetectTransmits uint8 + retransTimer time.Duration + expectedRetransmitTimer time.Duration + }{ + {"1:1s:1s", 1, time.Second, time.Second}, + {"2:1s:1s", 2, time.Second, time.Second}, + {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + // 0s is an invalid RetransmitTimer timer and will be fixed to + // the default RetransmitTimer value of 1s. + {"1:0s:1s", 1, 0, time.Second}, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + } + opts.NDPConfigs.RetransmitTimer = test.retransTimer + opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit + + // Should have sent an NDP NS immediately. + if got := stat.Value(); got != 1 { + t.Fatalf("got NeighborSolicit = %d, want = 1", got) + + } + + // Address should not be considered bound to the NIC yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for the remaining time - some delta (500ms), to + // make sure the address is still not resolved. + const delta = 500 * time.Millisecond + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + } + + // Should not have sent any more NS messages. + if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { + t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) + } + + // Validate the sent Neighbor Solicitation messages. + for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { + p := <-e.C + + // Make sure its an IPv6 packet. + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + + // Check NDP packet. + checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr1))) + } + }) + } + +} + +// TestDADFail tests to make sure that the DAD process fails if another node is +// detected to be performing DAD on the same address (receive an NS message from +// a node doing DAD for the same address), or if another node is detected to own +// the address already (receive an NA message for the tentative address). +func TestDADFail(t *testing.T) { + tests := []struct { + name string + makeBuf func(tgt tcpip.Address) buffer.Prependable + getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "RxSolicit", + func(tgt tcpip.Address) buffer.Prependable { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(tgt) + snmc := header.SolicitedNodeAddr(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, + }) + + return hdr + + }, + func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return s.NeighborSolicit + }, + }, + { + "RxAdvert", + func(tgt tcpip.Address) buffer.Prependable { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return hdr + + }, + func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return s.NeighborAdvert + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + } + opts.NDPConfigs.RetransmitTimer = time.Second * 2 + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + // Address should not be considered bound to the NIC yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Receive a packet to simulate multiple nodes owning or + // attempting to own the same address. + hdr := test.makeBuf(addr1) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + if got := stat.Value(); got != 1 { + t.Fatalf("got stat = %d, want = 1", got) + } + + // Wait for DAD to fail and make sure the address did + // not get resolved. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + }) + } +} + +// TestDADStop tests to make sure that the DAD process stops when an address is +// removed. +func TestDADStop(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Remove the address. This should stop DAD. + if err := s.RemoveAddress(1, addr1); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) + } + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + } +} + +// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NDP configurations using an invalid NICID. +func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) + } +} + +// TestSetNDPConfigurations tests that we can update and use per-interface NDP +// configurations without affecting the default NDP configurations or other +// interfaces' configurations. +func TestSetNDPConfigurations(t *testing.T) { + tests := []struct { + name string + dupAddrDetectTransmits uint8 + retransmitTimer time.Duration + expectedRetransmitTimer time.Duration + }{ + { + "OK", + 1, + time.Second, + time.Second, + }, + { + "Invalid Retransmit Timer", + 1, + 0, + time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + }) + + // This NIC(1)'s NDP configurations will be updated to + // be different from the default. + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Created before updating NIC(1)'s NDP configurations + // but updating NIC(1)'s NDP configurations should not + // affect other existing NICs. + if err := s.CreateNIC(2, e); err != nil { + t.Fatalf("CreateNIC(2) = %s", err) + } + + // Update the NDP configurations on NIC(1) to use DAD. + configs := stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + RetransmitTimer: test.retransmitTimer, + } + if err := s.SetNDPConfigurations(1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + } + + // Created after updating NIC(1)'s NDP configurations + // but the stack's default NDP configurations should not + // have been updated. + if err := s.CreateNIC(3, e); err != nil { + t.Fatalf("CreateNIC(3) = %s", err) + } + + // Add addresses for each NIC. + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + } + if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + } + + // Address should not be considered bound to NIC(1) yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Should get the address on NIC(2) and NIC(3) + // immediately since we should not have performed DAD on + // it as the stack was configured to not do DAD by + // default and we only updated the NDP configurations on + // NIC(1). + addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + } + if addr.Address != addr2 { + t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + } + addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + } + if addr.Address != addr3 { + t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + } + + // Sleep until right (500ms before) before resolution to + // make sure the address didn't resolve on NIC(1) yet. + const delta = 500 * time.Millisecond + time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + } + }) + } +} + +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(0) + ra := header.NDPRouterAdvert(pkt.NDPPayload()) + opts := ra.Options() + opts.Serialize(optSer) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + iph.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} +} + +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) +} + +// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix +// Information option. +// +// Note, raBufWithPI does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { + flags := uint8(0) + if onLink { + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 + } + + // A valid header.NDPPrefixInformation must be 30 bytes. + buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. + buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. + buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. + copy(buf[14:], prefix.Address) + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ + header.NDPPrefixInformation(buf[:]), + }) +} + +// TestNoRouterDiscovery tests that router discovery will not be performed if +// configured not to. +func TestNoRouterDiscovery(t *testing.T) { + // Being configured to discover routers means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // router discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// discovered flag set to discovered. +func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { + return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered router when the dispatcher asks it not to. +func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA for a router we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr2, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the router in the first place. + select { + case <-ndpDisp.routerC: + t.Fatal("should not have received any router events") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +func TestRouterDiscovery(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } + + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from another router (lladdr3) with non-zero lifetime. + l3Lifetime := time.Duration(6) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + expectRouterEvent(llAddr3, true) + + // Rx an RA from lladdr2 with lesser lifetime. + l2Lifetime := time.Duration(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } + + // Wait for lladdr2's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) +} + +// TestRouterDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +func TestRouterDiscoveryMaxRouters(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA from 2 more than the max number of discovered routers. + for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + linkAddr := []byte{2, 2, 3, 4, 5, 0} + linkAddr[5] = byte(i) + llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) + + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + + if i <= stack.MaxDiscoveredDefaultRouters { + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + + } else { + select { + case <-ndpDisp.routerC: + t.Fatal("should not have discovered a new router after we already discovered the max number of routers") + default: + } + } + } +} + +// TestNoPrefixDiscovery tests that prefix discovery will not be performed if +// configured not to. +func TestNoPrefixDiscovery(t *testing.T) { + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + + // Being configured to discover prefixes means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // prefix discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) + + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for prefix on nic with ID 1, and the +// discovered flag set to discovered. +func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered on-link prefix when the dispatcher asks it not to. +func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + + prefix, subnet, _ := prefixSubnetAddr(0, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have received any prefix events") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +func TestPrefixDiscovery(t *testing.T) { + t.Parallel() + + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) + + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) + + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) + + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation timer plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) +} + +func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { + // Update the infinite lifetime value to a smaller value so we can test + // that when we receive a PI with such a lifetime value, we do not + // invalidate the prefix. + const testInfiniteLifetimeSeconds = 2 + const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime + defer func() { + header.NDPInfiniteLifetime = saved + }() + + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + subnet := prefix.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix in an NDP Prefix Information option (PI) + // with infinite valid lifetime which should not get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + expectPrefixEvent(subnet, true) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with finite lifetime. + // The prefix should get invalidated after 1s. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(testInfiniteLifetime): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive an RA with finite lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + expectPrefixEvent(subnet, true) + + // Receive an RA with prefix with an infinite lifetime. + // The prefix should not be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with a prefix with a lifetime value greater than the + // set infinite lifetime value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): + } + + // Receive an RA with 0 lifetime. + // The prefix should get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) + expectPrefixEvent(subnet, false) +} + +// TestPrefixDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + + // Receive an RA with 2 more than the max number of discovered on-link + // prefixes. + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} + prefixAddr[7] = byte(i) + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixAddr[:]), + PrefixLen: 64, + } + prefixes[i] = prefix.Subnet() + buf := [30]byte{} + buf[0] = uint8(prefix.PrefixLen) + buf[1] = 128 + binary.BigEndian.PutUint32(buf[2:], 10) + copy(buf[14:], prefix.Address) + + optSer[i] = header.NDPPrefixInformation(buf[:]) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < stack.MaxDiscoveredOnLinkPrefixes { + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } else { + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") + default: + } + } + } +} + +// Checks to see if list contains an IPv6 address, item. +func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + for _, i := range list { + if i == protocolAddress { + return true + } + } + + return false +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. +func TestAutoGenAddr(t *testing.T) { + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } +} + +// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an +// auto-generated address only gets updated when required to, as specified in +// RFC 4862 section 5.5.3.e. +func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { + const infiniteVL = 4294967295 + const newMinVL = 5 + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + ovl uint32 + nvl uint32 + evl uint32 + }{ + // Should update the VL to the minimum VL for updating if the + // new VL is less than newMinVL but was originally greater than + // it. + { + "LargeVLToVLLessThanMinVLForUpdate", + 9999, + 1, + newMinVL, + }, + { + "LargeVLTo0", + 9999, + 0, + newMinVL, + }, + { + "InfiniteVLToVLLessThanMinVLForUpdate", + infiniteVL, + 1, + newMinVL, + }, + { + "InfiniteVLTo0", + infiniteVL, + 0, + newMinVL, + }, + + // Should not update VL if original VL was less than newMinVL + // and the new VL is also less than newMinVL. + { + "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", + newMinVL - 1, + newMinVL - 3, + newMinVL - 1, + }, + + // Should take the new VL if the new VL is greater than the + // remaining time or is greater than newMinVL. + { + "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", + newMinVL + 5, + newMinVL + 3, + newMinVL + 3, + }, + { + "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", + newMinVL - 3, + newMinVL - 1, + newMinVL - 1, + }, + { + "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", + newMinVL - 3, + newMinVL + 1, + newMinVL + 1, + }, + } + + const delta = 500 * time.Millisecond + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + + // + // Validate that the VL for the address got set + // to test.evl. + // + + // Make sure we do not get any invalidation + // events until atleast 500ms (delta) before + // test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): + } + + // Wait for another second (2x delta), but now + // we expect the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + +// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed +// by the user, its resources will be cleaned up and an invalidation event will +// be sent to the integrator. +func TestAutoGenAddrRemoval(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive a PI to auto-generate an address. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr, newAddr) + + // Removing the address should result in an invalidation event + // immediately. + if err := s.RemoveAddress(1, addr.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // Wait for the original valid lifetime to make sure the original timer + // got stopped/cleaned up. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that +// is already assigned to the NIC, the static address remains. +func TestAutoGenAddrStaticConflict(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Add the address as a static address before SLAAC tries to add it. + if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive a PI where the generated address will be the same as the one + // that we already have assigned statically. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") + default: + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Should not get an invalidation event after the PI's invalidation + // time. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } +} + +// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event +// to the integrator when an RA is received with the NDP Recursive DNS Server +// option with at least one valid address. +func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. + select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43719085e..e8401c673 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -19,7 +19,6 @@ import ( "sync" "sync/atomic" - "gvisor.dev/gvisor/pkg/ilist" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -34,17 +33,24 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - mu sync.RWMutex spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint stats NICStats + + // ndp is the NDP related state for NIC. + // + // Note, read and write operations on ndp require that the NIC is + // appropriately locked. + ndp ndpState } // NICStats includes transmitted and received stats. @@ -78,17 +84,26 @@ const ( NeverPrimaryEndpoint ) +// newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { - return &NIC{ + // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For + // example, make sure that the link address it provides is a valid + // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + + nic := &NIC{ stack: stack, id: id, name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), + packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -99,7 +114,93 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback Bytes: &tcpip.StatCounter{}, }, }, + ndp: ndpState{ + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + }, + } + nic.ndp.nic = nic + + // Register supported packet endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.packetEPs[netProto] = []PacketEndpoint{} + } + for _, netProto := range stack.networkProtocols { + nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + } + + return nic +} + +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + // + // Also auto-generate an IPv6 link-local address based on the NIC's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] + if !ok { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + if !n.stack.autoGenIPv6LinkLocal { + return nil } + + l2addr := n.linkEP.LinkAddress() + + // Only attempt to generate the link-local address if we have a + // valid MAC address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } + + addr := header.LinkLocalAddr(l2addr) + + _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + }, + }, CanBePrimaryEndpoint) + + return err } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -129,55 +230,13 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.getKind() == permanent && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { n.mu.RLock() defer n.mu.RUnlock() - list := n.primary[protocol] - if list == nil { - return nil - } - - for e := list.Front(); e != nil; e = e.Next() { - r := e.(*referencedNetworkEndpoint) - // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set. - switch r.ep.ID().LocalAddress { - case header.IPv4Broadcast, header.IPv4Any: - continue - } + for _, r := range n.primary[protocol] { if r.isValidForOutgoing() && r.tryIncRef() { return r } @@ -186,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative +} + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) } @@ -277,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) + }, peb, temporary, static) n.mu.Unlock() return ref @@ -287,13 +360,35 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { switch ref.getKind() { - case permanent: + case permanentTentative, permanent: // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent. + // Promote the endpoint to become permanent and respect + // the new peb. if ref.tryIncRef() { ref.setKind(permanent) + + refs := n.primary[ref.protocol] + for i, r := range refs { + if r == ref { + switch peb { + case CanBePrimaryEndpoint: + return ref, nil + case FirstPrimaryEndpoint: + if i == 0 { + return ref, nil + } + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + case NeverPrimaryEndpoint: + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + return ref, nil + } + } + } + + n.insertPrimaryEndpointLocked(ref, peb) + return ref, nil } // tryIncRef failing means the endpoint is scheduled to be removed once @@ -303,10 +398,13 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p n.removeEndpointLocked(ref) } } - return n.addAddressLocked(protocolAddress, peb, permanent) + + return n.addAddressLocked(protocolAddress, peb, permanent, static) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + // Sanity check. id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if _, ok := n.endpoints[id]; ok { @@ -324,12 +422,22 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } + + isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) + + // If the address is an IPv6 address and it is a permanent address, + // mark it as tentative so it goes through the DAD process. + if isIPv6Unicast && kind == permanent { + kind = permanentTentative + } + ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, } // Set up cache if link address resolution exists for this protocol. @@ -339,19 +447,24 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if isIPv6Unicast { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref - l, ok := n.primary[protocolAddress.Protocol] - if !ok { - l = &ilist.List{} - n.primary[protocolAddress.Protocol] = l - } + n.insertPrimaryEndpointLocked(ref, peb) - switch peb { - case CanBePrimaryEndpoint: - l.PushBack(ref) - case FirstPrimaryEndpoint: - l.PushFront(ref) + // If we are adding a tentative IPv6 address, start DAD. + if isIPv6Unicast && kind == permanentTentative { + if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { + return nil, err + } } return ref, nil @@ -368,16 +481,20 @@ func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { - // Don't include expired or tempory endpoints to avoid confusion and - // prevent the caller from using those. + // Don't include tentative, expired or temporary endpoints to + // avoid confusion and prevent the caller from using those. switch ref.getKind() { - case permanentExpired, temporary: + case permanentTentative, permanentExpired, temporary: + // TODO(b/140898488): Should tentative addresses be + // returned? continue } addrs = append(addrs, tcpip.ProtocolAddress{ @@ -391,6 +508,34 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints + // to avoid confusion and prevent the caller from using + // those. + switch ref.getKind() { + case permanentTentative, permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are @@ -435,6 +580,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { return append(sns, n.addressRanges...) } +// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required +// by peb. +// +// n MUST be locked. +func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { + switch peb { + case CanBePrimaryEndpoint: + n.primary[r.protocol] = append(n.primary[r.protocol], r) + case FirstPrimaryEndpoint: + n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + } +} + func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() @@ -452,9 +610,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { } delete(n.endpoints, id) - wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() - if wasInList { - n.primary[r.protocol].Remove(r) + refs := n.primary[r.protocol] + for i, ref := range refs { + if ref == r { + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + break + } } r.ep.Close() @@ -467,13 +628,48 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.getKind() != permanent { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok { + return tcpip.ErrBadLocalAddress + } + + kind := r.getKind() + if kind != permanent && kind != permanentTentative { return tcpip.ErrBadLocalAddress } + isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) + + if isIPv6Unicast { + // If we are removing a tentative IPv6 unicast address, stop + // DAD. + if kind == permanentTentative { + n.ndp.stopDuplicateAddressDetection(addr) + } + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + if r.configType == slaac { + n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + } + } + r.setKind(permanentExpired) - r.decRefLocked() + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if isIPv6Unicast { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -491,6 +687,18 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + // TODO(b/143102137): When implementing MLD, make sure MLD packets are + // not sent unless a valid link-local address is available for use on n + // as an MLD packet's source address must be a link-local address as + // outlined in RFC 3810 section 5. + id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -518,6 +726,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -534,10 +749,10 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -547,9 +762,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -557,36 +772,39 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr return } + // If no local link layer address is provided, assume it was sent + // directly to this NIC. + if local == "" { + local = n.linkEP.LinkAddress() + } + + // Are any packet sockets listening for this network protocol? + n.mu.RLock() + packetEPs := n.packetEPs[protocol] + // Check whether there are packet sockets listening for every protocol. + // If we received a packet with protocol EthernetProtocolAll, then the + // previous for loop will have handled it. + if protocol != header.EthernetProtocolAll { + packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + } + n.mu.RUnlock() + for _, ep := range packetEPs { + ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + } + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(vv.First()) - - n.stack.AddLinkAddress(n.id, src, remote) - - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) - } - } - n.mu.RUnlock() - return - } + src, dst := netProto.ParseAddresses(pkt.Data.First()) if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return } @@ -614,31 +832,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) - vv.RemoveFirst() + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } } return } - n.stack.stats.IP.InvalidAddressesReceived.Increment() + // If a packet socket handled the packet, don't treat it as invalid. + if len(packetEPs) == 0 { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + } } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -650,46 +871,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(vv.First()) < transProto.MinimumPacketSize() { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -700,20 +916,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } @@ -728,16 +941,80 @@ func (n *NIC) Stack() *Stack { return n.stack } +// isAddrTentative returns true if addr is tentative on n. +// +// Note that if addr is not associated with n, then this function will return +// false. It will only return true if the address is associated with the NIC +// AND it is tentative. +func (n *NIC) isAddrTentative(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok { + return false + } + + return ref.getKind() == permanentTentative +} + +// dupTentativeAddrDetected attempts to inform n that a tentative addr +// is a duplicate on a link. +// +// dupTentativeAddrDetected will delete the tentative address if it exists. +func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + ref, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok { + return tcpip.ErrBadAddress + } + + if ref.getKind() != permanentTentative { + return tcpip.ErrInvalidEndpointState + } + + return n.removePermanentAddressLocked(addr) +} + +// setNDPConfigs sets the NDP configurations for n. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNDPConfigs(c NDPConfigurations) { + c.validate() + + n.mu.Lock() + n.ndp.configs = c + n.mu.Unlock() +} + +// handleNDPRA handles an NDP Router Advertisement message that arrived on n. +func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.handleRA(ip, ra) +} + type networkEndpointKind int32 const ( + // A permanentTentative endpoint is a permanent address that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses will be of this kind until + // NDP's Duplicate Address Detection has resolved, or be deleted if + // the process results in detecting a duplicate address. + permanentTentative networkEndpointKind = iota + // A permanent endpoint is created by adding a permanent address (vs. a // temporary one) to the NIC. Its reference count is biased by 1 to avoid // removal when no route holds a reference to it. It is removed by explicitly // removing the permanent address from the NIC. - permanent networkEndpointKind = iota + permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. If its address is re-added before the endpoint is removed, its type @@ -753,8 +1030,50 @@ const ( temporary ) +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return tcpip.ErrNotSupported + } + n.packetEPs[netProto] = append(eps, ep) + + return nil +} + +func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return + } + + for i, epOther := range eps { + if epOther == ep { + n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + return + } + } +} + +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A slaac configured endpoint is an IPv6 endpoint that was + // added by SLAAC as per RFC 4862 section 5.5.3. + slaac +) + type referencedNetworkEndpoint struct { - ilist.Entry ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -769,6 +1088,10 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change after the endpoint is added to a NIC. + configType networkEndpointConfigType } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { @@ -802,11 +1125,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 67b70b2ee..61fd46d66 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -62,24 +60,64 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // this transport endpoint. It sets pkt.TransportHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw // transport protocol endpoints. RawTransportEndpoints receive the entire -// packet - including the link, network, and transport headers - as delivered -// to netstack. +// packet - including the network and transport headers - as delivered to +// netstack. type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) +} + +// PacketEndpoint is the interface that needs to be implemented by packet +// transport protocol endpoints. These endpoints receive link layer headers in +// addition to whatever they contain (usually network and transport layer +// headers and a payload). +type PacketEndpoint interface { + // HandlePacket is called by the stack when new packets arrive that + // match the endpoint. + // + // Implementers should treat packet as immutable and should copy it + // before before modification. + // + // linkHeader may have a length of 0, in which case the PacketEndpoint + // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of pkt. + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -109,7 +147,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -127,13 +167,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. + // + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -148,6 +196,19 @@ const ( PacketLoop ) +// NetworkHeaderParams are the header parameters given as input by the +// transport endpoint to the network. +type NetworkHeaderParams struct { + // Protocol refers to the transport protocol number. + Protocol tcpip.TransportProtocolNumber + + // TTL refers to Time To Live field of the IP-header. + TTL uint8 + + // TOS refers to TypeOfService or TrafficClass field of the IP-header. + TOS uint8 +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { @@ -171,12 +232,17 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have + // already been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + + // WritePackets writes packets to the given destination address and + // protocol. pkts must not be zero length. + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -188,8 +254,10 @@ type NetworkEndpoint interface { NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. - HandlePacket(r *Route, vv buffer.VectorisedView) + // this network endpoint. It sets pkt.NetworkHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -214,7 +282,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -231,9 +299,15 @@ type NetworkProtocol interface { // packets to the appropriate network endpoint after it has been handled by // the data link layer. type NetworkDispatcher interface { - // DeliverNetworkPacket finds the appropriate network protocol - // endpoint and hands the packet over for further processing. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // DeliverNetworkPacket finds the appropriate network protocol endpoint + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverNetworkPacket takes ownership of pkt. + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -255,12 +329,18 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityGSO + CapabilityHardwareGSO + + // CapabilitySoftwareGSO indicates the link endpoint supports of sending + // multiple packets using a single call (LinkEndpoint.WritePackets). + CapabilitySoftwareGSO ) // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. +// out through the implementer's data link endpoint. When a link header exists, +// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a @@ -282,13 +362,27 @@ type LinkEndpoint interface { // link endpoint. LinkAddress() tcpip.LinkAddress - // WritePacket writes a packet with the given protocol through the given - // route. + // WritePacket writes a packet with the given protocol through the + // given route. It sets pkt.LinkHeader if a link layer header exists. + // pkt.NetworkHeader and pkt.TransportHeader must have already been + // set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error + + // WritePackets writes packets with the given protocol through the + // given route. pkts must not be zero length. + // + // Right now, WritePackets is used only when the software segmentation + // offload is enabled. If it will be used for something else, it may + // require to change syscall filters. + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + + // WriteRawPacket writes a packet directly to the link. The packet + // should already have an ethernet header. + WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. @@ -297,6 +391,15 @@ type LinkEndpoint interface { // IsAttached returns whether a NetworkDispatcher is attached to the // endpoint. IsAttached() bool + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // For now, requesting that an endpoint's worker goroutine(s) stop is + // implementation specific. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -304,13 +407,14 @@ type LinkEndpoint interface { type InjectableLinkEndpoint interface { LinkEndpoint - // Inject injects an inbound packet. - Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // InjectInbound injects an inbound packet. + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) - // WriteRawPacket writes a fully formed outbound packet directly to the link. + // InjectOutbound writes a fully formed outbound packet directly to the + // link. // // dest is used by endpoints with multiple raw destinations. - WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -339,10 +443,10 @@ type LinkAddressResolver interface { type LinkAddressCache interface { // CheckLocalAddress determines if the given local address exists, and if it // does not exist. - CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver @@ -353,79 +457,22 @@ type LinkAddressCache interface { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) -} - -// TransportProtocolFactory functions are used by the stack to instantiate -// transport protocols. -type TransportProtocolFactory func() TransportProtocol - -// NetworkProtocolFactory provides methods to be used by the stack to -// instantiate network protocols. -type NetworkProtocolFactory func() NetworkProtocol - -// UnassociatedEndpointFactory produces endpoints for writing packets not -// associated with a particular transport protocol. Such endpoints can be used -// to write arbitrary packets that include the IP header. -type UnassociatedEndpointFactory interface { - NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) -} - -var ( - transportProtocols = make(map[string]TransportProtocolFactory) - networkProtocols = make(map[string]NetworkProtocolFactory) - - unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) -) - -// RegisterTransportProtocolFactory registers a new transport protocol factory -// with the stack so that it becomes available to users of the stack. This -// function is intended to be called by init() functions of the protocols. -func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { - transportProtocols[name] = p -} - -// RegisterNetworkProtocolFactory registers a new network protocol factory with -// the stack so that it becomes available to users of the stack. This function -// is intended to be called by init() functions of the protocols. -func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { - networkProtocols[name] = p -} - -// RegisterUnassociatedFactory registers a factory to produce endpoints not -// associated with any particular transport protocol. This function is intended -// to be called by init() functions of the protocols. -func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { - unassociatedFactory = f -} - -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v + RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() +// RawFactory produces endpoints for writing various types of raw packets. +type RawFactory interface { + // NewUnassociatedEndpoint produces endpoints for writing packets not + // associated with a particular transport protocol. Such endpoints can + // be used to write arbitrary packets that include the network header. + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) - return linkEndpoints[id] + // NewPacketEndpoint produces endpoints for reading and writing packets + // that include network and (when cooked is false) link layer headers. + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } // GSOType is the type of GSO segments. @@ -436,8 +483,14 @@ type GSOType int // Types of gso segments. const ( GSONone GSOType = iota + + // Hardware GSO types: GSOTCPv4 GSOTCPv6 + + // GSOSW is used for software GSO segments which have to be sent by + // endpoint.WritePackets. + GSOSW ) // GSO contains generic segmentation offload properties. @@ -465,3 +518,7 @@ type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 } + +// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. +// This isn't a hard limit, because it is never set into packet headers. +const SoftwareGSOMaxSize = (1 << 16) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5c8b7977a..34307ae07 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -47,8 +46,8 @@ type Route struct { // starts. ref *referencedNetworkEndpoint - // loop controls where WritePacket should send packets. - loop PacketLooping + // Loop controls where WritePacket should send packets. + Loop PacketLooping } // makeRoute initializes a new route. It takes ownership of the provided @@ -59,6 +58,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -67,7 +68,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip LocalLinkAddress: localLinkAddr, RemoteAddress: remoteAddr, ref: ref, - loop: loop, + Loop: loop, } } @@ -152,34 +153,54 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) + err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } return err } +// WritePackets writes the set of packets through the given route. +func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { + if !r.ref.isValidForOutgoing() { + return 0, tcpip.ErrInvalidEndpointState + } + + n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + } + r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + payloadSize := 0 + for i := 0; i < n; i++ { + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) + payloadSize += pkts[i].DataSize + } + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) return nil } @@ -208,10 +229,17 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6beca6ae8..0e88643a4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -17,18 +17,16 @@ // // For consumers, the only function of interest is New(), everything else is // provided by the tcpip/public package. -// -// For protocol implementers, RegisterTransportProtocolFactory() and -// RegisterNetworkProtocolFactory() are used to register protocol factories with -// the stack, which will then be used to instantiate protocol objects when -// consumers interact with the stack. package stack import ( + "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -46,11 +44,14 @@ const ( resolutionTimeout = 1 * time.Second // resolutionAttempts is set to the same ARP retries used in Linux. resolutionAttempts = 3 + + // DefaultTOS is the default type of service value for network endpoints. + DefaultTOS = 0 ) type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -344,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -351,7 +359,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver - unassociatedFactory UnassociatedEndpointFactory + // rawFactory creates raw endpoints. If nil, raw endpoints are + // disabled. It is set during Stack creation and is immutable. + rawFactory RawFactory demux *transportDemuxer @@ -359,13 +369,10 @@ type Stack struct { linkAddrCache *linkAddrCache - // raw indicates whether raw sockets may be created. It is set during - // Stack creation and is immutable. - raw bool - - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -394,10 +401,42 @@ type Stack struct { // icmpRateLimiter is a global rate limiter for all ICMP messages generated // by the stack. icmpRateLimiter *ICMPRateLimiter + + // seed is a one-time random value initialized at stack startup + // and is used to seed the TCP port picking on active connections + // + // TODO(gvisor.dev/issue/940): S/R this field. + seed uint32 + + // ndpConfigs is the default NDP configurations used by interfaces. + ndpConfigs NDPConfigurations + + // autoGenIPv6LinkLocal determines whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. type Options struct { + // NetworkProtocols lists the network protocols to enable. + NetworkProtocols []NetworkProtocol + + // TransportProtocols lists the transport protocols to enable. + TransportProtocols []TransportProtocol + // Clock is an optional clock source used for timestampping packets. // // If no Clock is specified, the clock source will be time.Now. @@ -411,44 +450,109 @@ type Options struct { // stack (false). HandleLocal bool - // Raw indicates whether raw sockets may be created. - Raw bool + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID + + // NDPConfigs is the default NDP configurations used by interfaces. + // + // By default, NDPConfigs will have a zero value for its + // DupAddrDetectTransmits field, implying that DAD will not be performed + // before assigning an address to a NIC. + NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determins whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // Note, setting this to true does not mean that a link-local address + // will be assigned right away, or at all. If Duplicate Address + // Detection is enabled, an address will only be assigned if it + // successfully resolves. If it fails, no further attempt will be made + // to auto-generate an IPv6 link-local address. + // + // The generated link-local address will follow RFC 4291 Appendix A + // guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // RawFactory produces raw endpoints. Raw endpoints are enabled only if + // this is non-nil. + RawFactory RawFactory } +// TransportEndpointInfo holds useful information about a transport endpoint +// which can be queried by monitoring tools. +// +// +stateify savable +type TransportEndpointInfo struct { + // The following fields are initialized at creation time and are + // immutable. + + NetProto tcpip.NetworkProtocolNumber + TransProto tcpip.TransportProtocolNumber + + // The following fields are protected by endpoint mu. + + ID TransportEndpointID + // BindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + BindNICID tcpip.NICID + BindAddr tcpip.Address + // RegisterNICID is the default NICID registered as a side-effect of + // connect or datagram write. + RegisterNICID tcpip.NICID +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*TransportEndpointInfo) IsEndpointInfo() {} + // New allocates a new networking stack with only the requested networking and // transport protocols configured with default options. // +// Note, NDPConfigurations will be fixed before being used by the Stack. That +// is, if an invalid value was provided, it will be reset to the default value. +// // Protocol options can be changed by calling the // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options // are supported. -func New(network []string, transport []string, opts Options) *Stack { +func New(opts Options) *Stack { clock := opts.Clock if clock == nil { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + + // Make sure opts.NDPConfigs contains valid values only. + opts.NDPConfigs.validate() + s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - raw: opts.Raw, - icmpRateLimiter: NewICMPRateLimiter(), + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + ndpConfigs: opts.NDPConfigs, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, + ndpDisp: opts.NDPDisp, } // Add specified network protocols. - for _, name := range network { - netProtoFactory, ok := networkProtocols[name] - if !ok { - continue - } - netProto := netProtoFactory() + for _, netProto := range opts.NetworkProtocols { s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -456,18 +560,14 @@ func New(network []string, transport []string, opts Options) *Stack { } // Add specified transport protocols. - for _, name := range transport { - transProtoFactory, ok := transportProtocols[name] - if !ok { - continue - } - transProto := transProtoFactory() + for _, transProto := range opts.TransportProtocols { s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } } - s.unassociatedFactory = unassociatedFactory + // Add the factory for raw endpoints, if present. + s.rawFactory = opts.RawFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -475,6 +575,11 @@ func New(network []string, transport []string, opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value @@ -536,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -602,12 +707,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if !s.raw { + if s.rawFactory == nil { return nil, tcpip.ErrNotPermitted } if !associated { - return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue) } t, ok := s.transportProtocols[transport] @@ -618,14 +723,19 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return t.proto.NewRawEndpoint(s, network, waiterQueue) } -// createNIC creates a NIC with the provided id and link-layer endpoint, and -// optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint +// NewPacketEndpoint creates a new packet endpoint listening for the given +// netProto. +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if s.rawFactory == nil { + return nil, tcpip.ErrNotPermitted } + return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) +} + +// createNIC creates a NIC with the provided id and link-layer endpoint, and +// optionally enable it. +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -638,40 +748,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -685,9 +795,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. @@ -745,7 +853,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.Addresses(), + ProtocolAddresses: nic.PrimaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -852,19 +960,37 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return tcpip.ErrUnknownNICID } -// GetMainNICAddress returns the first primary address (and the subnet that -// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's -// address if no primary addresses exist. Returns an error if the NIC doesn't -// exist or has no endpoints. +// AllAddresses returns a map of NICIDs to their protocol addresses (primary +// and non-primary). +func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) + for id, nic := range s.nics { + nics[id] = nic.AllAddresses() + } + return nics +} + +// GetMainNICAddress returns the first primary address and prefix for the given +// NIC and protocol. Returns an error if the NIC doesn't exist and an empty +// value if the NIC doesn't have a primary address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.getMainNICAddress(protocol) + nic, ok := s.nics[id] + if !ok { + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + for _, a := range nic.PrimaryAddresses() { + if a.Protocol == protocol { + return a.AddressWithPrefix, nil + } + } + return tcpip.AddressWithPrefix{}, nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -891,7 +1017,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { @@ -929,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. -func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { +func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() // If a NIC is specified, we try to find the address there only. - if nicid != 0 { - nic := s.nics[nicid] + if nicID != 0 { + nic := s.nics[nicID] if nic == nil { return 0 } @@ -994,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { } // AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.add(fullAddr, linkAddr) // TODO: provide a way for a transport endpoint to receive a signal // that AddLinkAddress for a particular address has been called. } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicid] + nic := s.nics[nicID] if nic == nil { s.mu.RUnlock() return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // RemoveWaker implements LinkAddressCache.RemoveWaker. -func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { +func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { s.mu.RLock() defer s.mu.RUnlock() - if nic := s.nics[nicid]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + if nic := s.nics[nicID]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.removeWaker(fullAddr, waker) } } @@ -1031,73 +1157,52 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} - s.mu.RLock() - defer s.mu.RUnlock() +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on @@ -1108,6 +1213,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { @@ -1122,6 +1290,109 @@ func (s *Stack) Resume() { } } +// RegisterPacketEndpoint registers ep with the stack, causing it to receive +// all traffic of the specified netProto on the given NIC. If nicID is 0, it +// receives traffic from every NIC. +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + // If no NIC is specified, capture on all devices. + if nicID == 0 { + // Register with each NIC. + for _, nic := range s.nics { + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + s.unregisterPacketEndpointLocked(0, netProto, ep) + return err + } + } + return nil + } + + // Capture on a specific device. + nic, ok := s.nics[nicID] + if !ok { + return tcpip.ErrUnknownNICID + } + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + return err + } + + return nil +} + +// UnregisterPacketEndpoint unregisters ep for packets of the specified +// netProto from the specified NIC. If nicID is 0, ep is unregistered from all +// NICs. +func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + s.mu.Lock() + defer s.mu.Unlock() + s.unregisterPacketEndpointLocked(nicID, netProto, ep) +} + +func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + // If no NIC is specified, unregister on all devices. + if nicID == 0 { + // Unregister with each NIC. + for _, nic := range s.nics { + nic.unregisterPacketEndpoint(netProto, ep) + } + return + } + + // Unregister in a single device. + nic, ok := s.nics[nicID] + if !ok { + return + } + nic.unregisterPacketEndpoint(netProto, ep) +} + +// WritePacket writes data directly to the specified NIC. It adds an ethernet +// header based on the arguments. +func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + // Add our own fake ethernet header. + ethFields := header.EthernetFields{ + SrcAddr: nic.linkEP.LinkAddress(), + DstAddr: dst, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + vv := buffer.View(fakeHeader).ToVectorisedView() + vv.Append(payload) + + if err := nic.linkEP.WriteRawPacket(vv); err != nil { + return err + } + + return nil +} + +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + if err := nic.linkEP.WriteRawPacket(payload); err != nil { + return err + } + + return nil +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1241,3 +1512,85 @@ func (s *Stack) SetICMPBurst(burst int) { func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } + +// IsAddrTentative returns true if addr is tentative on the NIC with ID id. +// +// Note that if addr is not associated with a NIC with id ID, then this +// function will return false. It will only return true if the address is +// associated with the NIC AND it is tentative. +func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return false, tcpip.ErrUnknownNICID + } + + return nic.isAddrTentative(addr), nil +} + +// DupTentativeAddrDetected attempts to inform the NIC with ID id that a +// tentative addr on it is a duplicate on a link. +func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.dupTentativeAddrDetected(addr) +} + +// SetNDPConfigurations sets the per-interface NDP configurations on the NIC +// with ID id to c. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.setNDPConfigs(c) + + return nil +} + +// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement +// message that it needs to handle. +func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.handleNDPRA(ip, ra) + + return nil +} + +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. +// +// NOTE: The seed is generated once during stack initialization only. +func (s *Stack) Seed() uint32 { + return s.seed +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return binary.LittleEndian.Uint32(b) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c6a8160af..8fc034ca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,11 +24,14 @@ import ( "sort" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -55,20 +58,20 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid + return f.nicID } func (f *fakeNetworkEndpoint) PrefixLen() int { @@ -83,32 +86,32 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b := pkt.Data.First() + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() + nb := pkt.Data.First() if len(nb) < fakeNetHeaderLen { return } - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,35 +119,41 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) + b := pkt.Header.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] - b[2] = byte(protocol) + b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + f.HandlePacket(r, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) } if loop&stack.PacketOut == 0 { return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -189,14 +198,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -222,12 +231,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeNetFactory() stack.NetworkProtocol { + return &fakeNetworkProtocol{} +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -245,7 +260,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -255,7 +272,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -265,7 +284,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -274,7 +295,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -284,7 +307,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -304,62 +329,67 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }) } -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := sendTo(s, addr, payload); err != nil { t.Error("sendTo failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("sendTo packet count: got = %d, want %d", got, want) } } -func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := send(r, payload); err != nil { t.Error("send failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("send packet count: got = %d, want %d", got, want) } } -func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) } } -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we expect to receive it. want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we do NOT expect to receive it. want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -369,9 +399,11 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -388,17 +420,19 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", linkEP, nil) + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -410,8 +444,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -442,10 +476,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - testSendTo(t, s, "\x05", linkEP1, nil) + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - testSendTo(t, s, "\x06", linkEP2, nil) + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -476,10 +510,12 @@ func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -491,8 +527,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -554,10 +590,12 @@ func TestAddressRemoval(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -578,15 +616,15 @@ func TestAddressRemoval(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -599,11 +637,13 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) @@ -626,17 +666,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -644,11 +684,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -681,17 +721,19 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) for _, promiscuous := range []bool{true, false} { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -708,13 +750,13 @@ func TestEndpointExpiration(t *testing.T) { buf[0] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -722,55 +764,55 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. //----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + verifyAddress(t, s, nicID, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + verifyAddress(t, s, nicID, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 5. Take a reference to the endpoint by getting a route. Verify that // we can still send/receive, including sending using the route. @@ -779,64 +821,64 @@ func TestEndpointExpiration(t *testing.T) { if err != nil { t.Fatal("FindRoute failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + verifyAddress(t, s, nicID, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + verifyAddress(t, s, nicID, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } }) } @@ -844,10 +886,12 @@ func TestEndpointExpiration(t *testing.T) { } func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -867,13 +911,13 @@ func TestPromiscuousMode(t *testing.T) { // have a matching endpoint. const localAddrByte byte = 0x01 buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -886,7 +930,7 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestSpoofingWithAddress(t *testing.T) { @@ -894,10 +938,12 @@ func TestSpoofingWithAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x02") dstAddr := tcpip.Address("\x03") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -930,14 +976,14 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. - testSendTo(t, s, dstAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) // FindRoute should also work with a local address that exists on the NIC. r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) @@ -945,23 +991,25 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. - testSend(t, r, linkEP, nil) + testSend(t, r, ep, nil) } func TestSpoofingNoAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -980,7 +1028,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -992,49 +1040,138 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + } +} + +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, + } + s.SetRouteTable(rt) + + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -1074,10 +1211,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } { t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1130,10 +1269,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { // Add a range of addresses, then check that a packet is delivered. func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1159,7 +1300,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddAddressRange failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { @@ -1196,10 +1337,12 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub // existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1234,10 +1377,12 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { // Set a range of addresses, then send a packet to a destination outside the // range and then check it doesn't get delivered. func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,11 +1407,14 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddAddressRange failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{}, + }) // Try an unsupported network protocol. if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -1319,9 +1467,11 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S } func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1360,9 +1510,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. @@ -1400,20 +1552,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) + // No primary addresses present. + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) } } else { - // At least one primary address was added, expect a valid - // address and prefixLen. - gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { - t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) + // At least one primary address was added, verify the returned + // address is in the list of primary addresses we added. + if _, ok := primaryAddrAdded[gotAddr]; !ok { + t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) } } }) @@ -1425,9 +1577,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1452,19 +1606,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) + } + if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { + t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } - // Check that we get an error after removal. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) + // Check that we get no address after removal. + gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } }) } @@ -1479,8 +1639,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address { } func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + t.Helper() + if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) } sort.Slice(gotAddresses, func(i, j int) bool { @@ -1499,10 +1661,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + const nicID = 1 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1510,7 +1674,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1519,15 +1683,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + const nicID = 1 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1544,22 +1710,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + const nicID = 1 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1570,7 +1738,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1580,15 +1748,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + const nicID = 1 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1607,7 +1777,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1615,14 +1785,16 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { @@ -1639,7 +1811,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1653,9 +1827,9 @@ func TestNICStats(t *testing.T) { if err := sendTo(s, "\x01", payload); err != nil { t.Fatal("sendTo failed: ", err) } - want := uint64(linkEP1.Drain()) + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1666,19 +1840,21 @@ func TestNICStats(t *testing.T) { func TestNICForwarding(t *testing.T) { // Create a stack with the fake network protocol, two NICs, each with // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1697,10 +1873,12 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } @@ -1715,8 +1893,296 @@ func TestNICForwarding(t *testing.T) { } } -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) +// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses +// (or lack there-of if disabled (default)). Note, DAD will be disabled in +// these tests. +func TestNICAutoGenAddr(t *testing.T) { + tests := []struct { + name string + autoGen bool + linkAddr tcpip.LinkAddress + shouldGen bool + }{ + { + "Disabled", + false, + linkAddr1, + false, + }, + { + "Enabled", + true, + linkAddr1, + true, + }, + { + "Nil MAC", + true, + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty MAC", + true, + tcpip.LinkAddress(""), + false, + }, + { + "Invalid MAC", + true, + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Multicast MAC", + true, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Unspecified MAC", + true, + tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + + if test.shouldGen { + // Should have auto-generated an address and + // resolved immediately (DAD is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 +// link-local addresses will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + +// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected +// when an address's kind gets "promoted" to permanent from permanentExpired. +func TestNewPEBOnPromotionToPermanent(t *testing.T) { + pebs := []stack.PrimaryEndpointBehavior{ + stack.NeverPrimaryEndpoint, + stack.CanBePrimaryEndpoint, + stack.FirstPrimaryEndpoint, + } + + for _, pi := range pebs { + for _, ps := range pebs { + t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + // Add a permanent address with initial + // PrimaryEndpointBehavior (peb), pi. If pi is + // NeverPrimaryEndpoint, the address should not + // be returned by a call to GetMainNICAddress; + // else, it should. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + addr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if pi == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatalf("NewSubnet failed:", err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // Take a route through the address so its ref + // count gets incremented and does not actually + // get deleted when RemoveAddress is called + // below. This is because we want to test that a + // new peb is respected when an address gets + // "promoted" to permanent from a + // permanentExpired kind. + r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + defer r.Release() + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + + // + // At this point, the address should still be + // known by the NIC, but have its + // kind = permanentExpired. + // + + // Add some other address with peb set to + // FirstPrimaryEndpoint. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + + } + + // Add back the address we removed earlier and + // make sure the new peb was respected. + // (The address should just be promoted now). + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + var primaryAddrs []tcpip.Address + for _, pa := range s.NICInfo()[1].ProtocolAddresses { + primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) + } + var expectedList []tcpip.Address + switch ps { + case stack.FirstPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x01", + "\x03", + } + case stack.CanBePrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + "\x01", + } + case stack.NeverPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + } + } + if !cmp.Equal(primaryAddrs, expectedList) { + t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) + } + + // Once we remove the other address, if the new + // peb, ps, was NeverPrimaryEndpoint, no address + // should be returned by a call to + // GetMainNICAddress; else, our original address + // should be returned. + if err := s.RemoveAddress(1, "\x03"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + addr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if ps == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else { + if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + } + }) + } + } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..67c21be42 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,10 +17,10 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -35,7 +35,7 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint @@ -43,19 +43,122 @@ type transportEndpoints struct { // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } } - delete(eps.endpoints, id) + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if isMulticastOrBroadcast(id.LocalAddress) { + mpep.handlePacketAll(r, id, pkt) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -75,7 +178,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +188,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +200,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool +} + +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps } // reciprocalScale scales a value into range [0, n). @@ -117,9 +234,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +246,77 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, pkt) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, pkt.Clone()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } +} + +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } } -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +341,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -257,33 +389,49 @@ var loopbackSubnet = func() tcpip.Subnet { }() // deliverPacket attempts to find one or more matching transport endpoints, and -// then, if matches are found, delivers the packet to them. Returns true if it -// found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +// then, if matches are found, delivers the packet to them. Returns true if +// the packet no longer needs to be handled. +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a UDP broadcast or multicast, then find all matching + // transport endpoints. If the packet is a TCP packet with a non-unicast + // source or destination address, then do nothing further and instruct + // the caller to do the same. + var destEps []*endpointsByNic + switch protocol { + case header.UDPProtocolNumber: + if isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, id) + break + } + + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } + + case header.TCPProtocolNumber: + if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single + // source and a single destination; the addresses must + // be unicast. + eps.mu.RUnlock() + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true + } + + fallthrough + + default: + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) } - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { - destEps = append(destEps, ep) } eps.mu.RUnlock() @@ -297,17 +445,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // Deliver the packet. - for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + // HandlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEps[:len(destEps)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEps[len(destEps)-1].handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -321,7 +471,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -331,7 +481,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -339,7 +489,7 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() // Fail if we didn't find one. @@ -348,15 +498,16 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -364,7 +515,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -372,15 +523,54 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) + } + return matchedEPs +} + +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } } + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -391,7 +581,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return nil + return tcpip.ErrNotSupported } eps.mu.Lock() @@ -418,3 +608,11 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} + +func isUnicast(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..3b28b06d0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,354 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicID := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ca185279e..748ce4ea5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -38,19 +38,27 @@ const ( // Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't // use it. type fakeTransportEndpoint struct { - id stack.TransportEndpointID + stack.TransportEndpointInfo stack *stack.Stack - netProto tcpip.NetworkProtocolNumber proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint } -func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} +func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { + return &f.TransportEndpointInfo +} + +func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { + return nil +} + +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -65,17 +73,20 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(v).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -91,6 +102,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -121,8 +137,8 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() // Try to register so that we can start receiving packets. - f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + f.ID.RemoteAddress = addr.Addr + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -132,6 +148,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -163,7 +183,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } @@ -179,14 +200,16 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - id: id, - stack: f.stack, - netProto: f.netProto, + stack: f.stack, + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, proto: f.proto, peerAddr: r.RemoteAddress, route: r.Clone(), @@ -194,7 +217,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -203,15 +226,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -236,7 +259,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -251,7 +274,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -277,10 +300,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeTransFactory() stack.TransportProtocol { + return &fakeTransportProtocol{} +} + func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -315,7 +345,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -324,7 +356,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -333,16 +367,21 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -383,7 +422,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -392,7 +433,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -401,16 +444,21 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -452,7 +500,10 @@ func TestTransportSend(t *testing.T) { } func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) // Try an unsupported transport protocol. if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -493,20 +544,23 @@ func TestTransportOptions(t *testing.T) { } func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +599,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,21 +615,15 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 418e771d2..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -57,6 +57,9 @@ type Error struct { // String implements fmt.Stringer.String. func (e *Error) String() string { + if e == nil { + return "<nil>" + } return e.msg } @@ -228,6 +231,13 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// Equal returns true if s equals o. +// +// Needed to use cmp.Equal on Subnet as its fields are unexported. +func (s Subnet) Equal(o Subnet) bool { + return s == o +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 @@ -252,7 +262,7 @@ type FullAddress struct { // This may not be used by all endpoint types. NIC NICID - // Addr is the network address. + // Addr is the network or link layer address. Addr Address // Port is the transport port. @@ -261,31 +271,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -295,7 +308,7 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packed used to create + // Timestamp is the time (in ns) that the last packet used to create // the read data was received. Timestamp int64 @@ -304,6 +317,18 @@ type ControlMessages struct { // Inq is the number of bytes ready to be received. Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS int8 + + // HasTClass indicates whether Tclass is valid/set. + HasTClass bool + + // Tclass is the IPv6 traffic class of the associated packet. + TClass int32 } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -338,7 +363,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -398,6 +423,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptInt sets a socket option, for simple cases where a value + // has the int type. + SetSockOptInt(opt SockOpt, v int) *Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error @@ -419,6 +448,26 @@ type Endpoint interface { // IPTables returns the iptables for this endpoint's stack. IPTables() (iptables.IPTables, error) + + // Info returns a copy to the transport endpoint info. + Info() EndpointInfo + + // Stats returns a reference to the endpoint stats. + Stats() EndpointStats +} + +// EndpointInfo is the interface implemented by each endpoint info struct. +type EndpointInfo interface { + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo + // marker interface. + IsEndpointInfo() +} + +// EndpointStats is the interface implemented by each endpoint stats struct. +type EndpointStats interface { + // IsEndpointStats is an empty method to implement the tcpip.EndpointStats + // marker interface. + IsEndpointStats() } // WriteOptions contains options for Endpoint.Write. @@ -432,16 +481,38 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. type SockOpt int const ( - // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of - // unread bytes in the input buffer should be returned. + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption SockOpt = iota + // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the send buffer size option. + SendBufferSizeOption + + // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the receive buffer size option. + ReceiveBufferSizeOption + + // SendQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the output buffer should be returned. + SendQueueSizeOption + + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -450,27 +521,10 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send -// buffer size option. -type SendBufferSizeOption int - -// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the -// receive buffer size option. -type ReceiveBufferSizeOption int - -// SendQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the output buffer should be returned. -type SendQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int -// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be -// sent out immediately by the transport protocol. For TCP, it determines if the -// Nagle algorithm is on or off. -type DelayOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -483,6 +537,10 @@ type ReuseAddressOption int // to be bound to an identical socket address. type ReusePortOption int +// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets +// should bind only on a specific NIC. +type BindToDeviceOption string + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int @@ -518,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user +// specified timeout for a given TCP connection. +// See: RFC5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string @@ -534,6 +597,22 @@ type ModerateReceiveBufferOption bool // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. type MaxSegOption int +// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop +// limit value for unicast messages. The default is protocol specific. +// +// A zero value indicates the default. +type TTLOption uint8 + +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -575,6 +654,18 @@ type OutOfBandInlineOption int // datagram sockets are allowed to send packets to a broadcast address. type BroadcastOption int +// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify +// a default TTL. +type DefaultTTLOption uint8 + +// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv4 packets from the endpoint. +type IPv4TOSOption uint8 + +// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv6 packets from the endpoint. +type IPv6TrafficClassOption uint8 + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -600,9 +691,6 @@ func (r Route) String() string { return out.String() } -// LinkEndpointID represents a data link layer endpoint. -type LinkEndpointID uint64 - // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -619,6 +707,11 @@ func (s *StatCounter) Increment() { s.IncrementBy(1) } +// Decrement minuses one to the counter. +func (s *StatCounter) Decrement() { + s.IncrementBy(^uint64(0)) +} + // Value returns the current value of the counter. func (s *StatCounter) Value() uint64 { return atomic.LoadUint64(&s.count) @@ -807,6 +900,14 @@ type IPStats struct { // OutgoingPacketErrors is the total number of IP packets which failed // to write to a link-layer endpoint. OutgoingPacketErrors *StatCounter + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived *StatCounter + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived *StatCounter } // TCPStats collects TCP-specific stats. @@ -819,6 +920,23 @@ type TCPStats struct { // successfully via Listen. PassiveConnectionOpenings *StatCounter + // CurrentEstablished is the number of TCP connections for which the + // current state is either ESTABLISHED or CLOSE-WAIT. + CurrentEstablished *StatCounter + + // EstablishedResets is the number of times TCP connections have made + // a direct transition to the CLOSED state from either the + // ESTABLISHED state or the CLOSE-WAIT state. + EstablishedResets *StatCounter + + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of keep-alive time out. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter @@ -853,6 +971,9 @@ type TCPStats struct { // SegmentsSent is the number of TCP segments sent. SegmentsSent *StatCounter + // SegmentSendErrors is the number of TCP segments failed to be sent. + SegmentSendErrors *StatCounter + // ResetsSent is the number of TCP resets sent. ResetsSent *StatCounter @@ -905,6 +1026,9 @@ type UDPStats struct { // PacketsSent is the number of UDP datagrams sent via sendUDP. PacketsSent *StatCounter + + // PacketSendErrors is the number of datagrams failed to be sent. + PacketSendErrors *StatCounter } // Stats holds statistics about the networking stack. @@ -915,7 +1039,7 @@ type Stats struct { // stack that were for an unknown or unsupported protocol. UnknownProtocolRcvdPackets *StatCounter - // MalformedRcvPackets is the number of packets received by the stack + // MalformedRcvdPackets is the number of packets received by the stack // that were deemed malformed. MalformedRcvdPackets *StatCounter @@ -935,18 +1059,95 @@ type Stats struct { UDP UDPStats } +// ReceiveErrors collects packet receive errors within transport endpoint. +type ReceiveErrors struct { + // ReceiveBufferOverflow is the number of received packets dropped + // due to the receive buffer being full. + ReceiveBufferOverflow StatCounter + + // MalformedPacketsReceived is the number of incoming packets + // dropped due to the packet header being in a malformed state. + MalformedPacketsReceived StatCounter + + // ClosedReceiver is the number of received packets dropped because + // of receiving endpoint state being closed. + ClosedReceiver StatCounter +} + +// SendErrors collects packet send errors within the transport layer for +// an endpoint. +type SendErrors struct { + // SendToNetworkFailed is the number of packets failed to be written to + // the network endpoint. + SendToNetworkFailed StatCounter + + // NoRoute is the number of times we failed to resolve IP route. + NoRoute StatCounter + + // NoLinkAddr is the number of times we failed to resolve ARP. + NoLinkAddr StatCounter +} + +// ReadErrors collects segment read errors from an endpoint read call. +type ReadErrors struct { + // ReadClosed is the number of received packet drops because the endpoint + // was shutdown for read. + ReadClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter +} + +// WriteErrors collects packet write errors from an endpoint write call. +type WriteErrors struct { + // WriteClosed is the number of packet drops because the endpoint + // was shutdown for write. + WriteClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter + + // InvalidArgs is the number of times invalid input arguments were + // provided for endpoint Write call. + InvalidArgs StatCounter +} + +// TransportEndpointStats collects statistics about the endpoint. +type TransportEndpointStats struct { + // PacketsReceived is the number of successful packet receives. + PacketsReceived StatCounter + + // PacketsSent is the number of successful packet sends. + PacketsSent StatCounter + + // ReceiveErrors collects packet receive errors within transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects packet read errors from an endpoint read call. + ReadErrors ReadErrors + + // SendErrors collects packet send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects packet write errors from an endpoint write call. + WriteErrors WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*TransportEndpointStats) IsEndpointStats() {} + func fillIn(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - if s := v.Addr().Interface().(**StatCounter); *s == nil { - *s = &StatCounter{} + if s, ok := v.Addr().Interface().(**StatCounter); ok { + if *s == nil { + *s = new(StatCounter) } - case reflect.Struct: + } else { fillIn(v) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -957,6 +1158,26 @@ func (s Stats) FillIn() Stats { return s } +// Clone returns a copy of the TransportEndpointStats by atomically reading +// each field. +func (src *TransportEndpointStats) Clone() TransportEndpointStats { + var dst TransportEndpointStats + clone(reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src).Elem()) + return dst +} + +func clone(dst reflect.Value, src reflect.Value) { + for i := 0; i < dst.NumField(); i++ { + d := dst.Field(i) + s := src.Field(i) + if c, ok := s.Addr().Interface().(*StatCounter); ok { + d.Addr().Interface().(*StatCounter).IncrementBy(c.Value()) + } else { + clone(d, s) + } + } +} + // String implements the fmt.Stringer interface. func (a Address) String() string { switch len(a) { @@ -1082,6 +1303,47 @@ func (a AddressWithPrefix) String() string { return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) } +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := len(a.Address) + if a.PrefixLen <= 0 { + return Subnet{ + address: Address(strings.Repeat("\x00", addrLen)), + mask: AddressMask(strings.Repeat("\x00", addrLen)), + } + } + if a.PrefixLen >= addrLen*8 { + return Subnet{ + address: a.Address, + mask: AddressMask(strings.Repeat("\xff", addrLen)), + } + } + + sa := make([]byte, addrLen) + sm := make([]byte, addrLen) + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa[i] = a.Address[i] + sm[i] = 0xff + n -= 8 + continue + } + sm[i] = ^byte(0xff >> n) + sa[i] = a.Address[i] & sm[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(Address(sa), AddressMask(sm)) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { @@ -1102,8 +1364,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index fb3a0a5ee..8c0aacffa 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) { t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) } } + +func TestAddressWithPrefixSubnet(t *testing.T) { + tests := []struct { + addr Address + prefixLen int + subnetAddr Address + subnetMask AddressMask + }{ + {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, + {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, + {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + } + for _, tt := range tests { + ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} + gotSubnet := ap.Subnet() + wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) + if err != nil { + t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + continue + } + if gotSubnet != wantSubnet { + t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) + } + } +} diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index a52262e87..48764b978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d78a162b8..d8c5b5058 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "icmp_packet_list", out = "icmp_packet_list.go", @@ -38,11 +38,3 @@ go_library( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "icmp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e1f622af6..9c40931b5 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -52,12 +49,13 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -73,30 +71,32 @@ type endpoint struct { sndBufSize int // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags - id stack.TransportEndpointID state endpointState - // bindNICID and bindAddr are set via calls to Bind(). They are used to - // reject attempts to send data or connect via a different NIC or - // address - bindNICID tcpip.NICID - bindAddr tcpip.Address - // regNICID is the default NIC to be used when callers don't specify a - // NIC. - regNICID tcpip.NICID - route stack.Route `state:"manual"` -} - -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + route stack.Route `state:"manual"` + ttl uint8 + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -104,7 +104,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -143,6 +143,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -204,7 +205,30 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -254,13 +278,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + nicID := to.NIC + if e.BindNICID != 0 { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicID = e.BindNICID } toCopy := *to @@ -271,7 +295,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -289,17 +313,17 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.id.LocalPort, v) + err = send4(route, e.ID.LocalPort, v, e.ttl) case header.IPv6ProtocolNumber: - err = send6(route, e.id.LocalPort, v) + err = send6(route, e.ID.LocalPort, v, e.ttl) } if err != nil { @@ -314,8 +338,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(o) + e.mu.Unlock() + } + + return nil +} + +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return nil } @@ -331,6 +367,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -341,28 +389,22 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() + case *tcpip.KeepaliveEnabledOption: + *o = 0 return nil - case *tcpip.ReceiveBufferSizeOption: + case *tcpip.TTLOption: e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + *o = tcpip.TTLOption(e.ttl) e.rcvMu.Unlock() return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } } -func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } @@ -384,10 +426,17 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: data.ToVectorisedView(), + TransportHeader: buffer.View(icmpv4), + }) } -func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -404,21 +453,28 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - icmpv6.SetChecksum(0) - icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + dataVV := data.ToVectorisedView() + icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: dataVV, + TransportHeader: buffer.View(icmpv6), + }) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { return 0, tcpip.ErrNoRoute } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -435,20 +491,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC localPort := uint16(0) switch e.state { case stateBound, stateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -459,7 +515,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -476,14 +532,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } - e.id = id + e.ID = id e.route = r.Clone() - e.regNICID = nicid + e.RegisterNICID = nicID e.state = stateConnected @@ -534,18 +590,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -592,8 +648,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = addr.NIC + e.ID = id + e.RegisterNICID = addr.NIC // Mark endpoint as bound. e.state = stateBound @@ -616,8 +672,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return err } - e.bindNICID = addr.NIC - e.bindAddr = addr.Addr + e.BindNICID = addr.NIC + e.BindAddr = addr.Addr return nil } @@ -628,9 +684,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -644,9 +700,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -670,19 +726,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Only accept echo replies. - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) + h := header.ICMPv4(pkt.Data.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) + h := header.ICMPv6(pkt.Data.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } } @@ -690,31 +748,39 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + packet.data = pkt.Data - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() - + e.stats.PacketsReceived.Increment() // Notify any waiters that there's data to be read now. if wasEmpty { e.waiterQueue.Notify(waiter.EventIn) @@ -722,7 +788,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't @@ -730,3 +796,20 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index c587b96b6..9d263c0ec 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -76,19 +76,19 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { panic(err) } - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + e.ID.LocalAddress = e.route.LocalAddress + } else if len(e.ID.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 1eb790932..9ce500e80 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -14,10 +14,9 @@ // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport // protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and -// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or -// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when -// calling stack.New(). Then endpoints can be created by passing +// must be added to the project, and activated on the stack by passing +// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport +// protocols when calling stack.New(). Then endpoints can be created by passing // icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number // when calling Stack.NewEndpoint(). package icmp @@ -34,15 +33,9 @@ import ( ) const ( - // ProtocolName4 is the string representation of the icmp protocol name. - ProtocolName4 = "icmp4" - // ProtocolNumber4 is the ICMP protocol number. ProtocolNumber4 = header.ICMPv4ProtocolNumber - // ProtocolName6 is the string representation of the icmp protocol name. - ProtocolName6 = "icmp6" - // ProtocolNumber6 is the IPv6-ICMP protocol number. ProtocolNumber6 = header.ICMPv6ProtocolNumber ) @@ -111,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -125,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol { - return &protocol{ProtocolNumber4} - }) +// NewProtocol4 returns an ICMPv4 transport protocol. +func NewProtocol4() stack.TransportProtocol { + return &protocol{ProtocolNumber4} +} - stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol { - return &protocol{ProtocolNumber6} - }) +// NewProtocol6 returns an ICMPv6 transport protocol. +func NewProtocol6() stack.TransportProtocol { + return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD new file mode 100644 index 000000000..44b58ff6b --- /dev/null +++ b/pkg/tcpip/transport/packet/BUILD @@ -0,0 +1,38 @@ +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "packet", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "packet", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "packet_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/iptables", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go new file mode 100644 index 000000000..0010b5e5f --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packet provides the implementation of packet sockets (see +// packet(7)). Packet sockets allow applications to: +// +// * manually write and inspect link, network, and transport headers +// * receive all traffic of a given network protocol, or all protocols +// +// Packet sockets are similar to raw sockets, but provide even more power to +// users, letting them effectively talk directly to the network device. +// +// Packet sockets skip the input and output iptables chains. +package packet + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal +// to have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + cooked bool + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +// NewEndpoint returns a new packet endpoint. +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + }, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { + return nil, err + } + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (ep *endpoint) ModerateRecvBuf(copied int) {} + +// IPTables implements tcpip.Endpoint.IPTables. +func (ep *endpoint) IPTables() (iptables.IPTables, error) { + return ep.stack.IPTables(), nil +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + ep.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // TODO(b/129292371): Implement. + return 0, nil, tcpip.ErrInvalidOptionValue +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be +// disconnected, and this function always returns tpcip.ErrNotSupported. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be +// connected, and this function always returnes tcpip.ErrNotSupported. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used +// with Shutdown, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with +// Listen, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with +// Accept, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + // TODO(gvisor.dev/issue/173): Add Bind support. + + // "By default, all packets of the specified protocol type are passed + // to a packet socket. To get packets only from a specific interface + // use bind(2) specifying an address in a struct sockaddr_ll to bind + // the packet socket to an interface. Fields used for binding are + // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." + // - packet(7). + + return tcpip.ErrNotSupported +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be +// used with SetSockOpt, and this function always returns +// tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// HandlePacket implements stack.PacketEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if ep.rcvClosed { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + var packet packet + // TODO(b/129292371): Return network protocol. + if len(pkt.LinkHeader) > 0 { + // Get info directly from the ethernet header. + hdr := header.Ethernet(pkt.LinkHeader) + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(hdr.SourceAddress()), + } + } else { + // Guess the would-be ethernet header. + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(localAddr), + } + } + + if ep.cooked { + // Cooked packets can simply be queued. + packet.data = pkt.Data + } else { + // Raw packets need their ethernet headers prepended before + // queueing. + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + } + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(&packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + ep.stats.PacketsReceived.Increment() + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} + +// Info returns a copy of the endpoint info. +func (ep *endpoint) Info() tcpip.EndpointInfo { + ep.mu.RLock() + // Make a copy of the endpoint info. + ret := ep.TransportEndpointInfo + ep.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (ep *endpoint) Stats() tcpip.EndpointStats { + return &ep.stats +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go new file mode 100644 index 000000000..9b88f17e4 --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -0,0 +1,72 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. + if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 7241f6c19..00991ac8e 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,17 +1,17 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( - name = "packet_list", - out = "packet_list.go", + name = "raw_packet_list", + out = "raw_packet_list.go", package = "raw", - prefix = "packet", + prefix = "rawPacket", template = "//pkg/ilist:generic_list", types = { - "Element": "*packet", - "Linker": "*packet", + "Element": "*rawPacket", + "Linker": "*rawPacket", }, ) @@ -20,8 +20,8 @@ go_library( srcs = [ "endpoint.go", "endpoint_state.go", - "packet_list.go", "protocol.go", + "raw_packet_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], @@ -34,14 +34,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/packet", "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -17,8 +17,7 @@ // // * manually write and inspect transport layer headers and payloads // * receive all traffic of a given transport protocol (e.g. ICMP or UDP) -// * optionally write and inspect network layer and link layer headers for -// packets +// * optionally write and inspect network layer headers of packets // // Raw sockets don't have any notion of ports, and incoming packets are // demultiplexed solely by protocol number. Thus, a raw UDP endpoint will @@ -38,15 +37,11 @@ import ( ) // +stateify savable -type packet struct { - packetEntry +type rawPacket struct { + rawPacketEntry // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -62,18 +57,17 @@ type packet struct { // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + rcvList rawPacketList rcvBufSizeMax int `state:".(int)"` rcvBufSize int rcvClosed bool @@ -84,35 +78,28 @@ type endpoint struct { closed bool connected bool bound bool - // registeredNIC is the NIC to which th endpoint is explicitly - // registered. Is set when Connect or Bind are used to specify a NIC. - registeredNIC tcpip.NICID - // boundNIC and boundAddr are set on calls to Bind(). When callers - // attempt actions that would invalidate the binding data (e.g. sending - // data via a NIC other than boundNIC), the endpoint will return an - // error. - boundNIC tcpip.NICID - boundAddr tcpip.Address // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route stack.Route `state:"manual"` + stats tcpip.TransportEndpointStats `state:"nosave"` } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL and AF_PACKET. func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } - ep := &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + e := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, @@ -123,115 +110,139 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - ep.rcvBufSizeMax = 0 - ep.waiterQueue = nil - return ep, nil + e.rcvBufSizeMax = 0 + e.waiterQueue = nil + return e, nil } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { return nil, err } - return ep, nil + return e, nil } // Close implements tcpip.Endpoint.Close. -func (ep *endpoint) Close() { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed || !ep.associated { + if e.closed || !e.associated { return } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) - ep.rcvMu.Lock() - defer ep.rcvMu.Unlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() // Clear the receive list. - ep.rcvClosed = true - ep.rcvBufSize = 0 - for !ep.rcvList.Empty() { - ep.rcvList.Remove(ep.rcvList.Front()) + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + e.rcvList.Remove(e.rcvList.Front()) } - if ep.connected { - ep.route.Release() - ep.connected = false + if e.connected { + e.route.Release() + e.connected = false } - ep.closed = true + e.closed = true - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { - return ep.stack.IPTables(), nil +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil } // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !ep.associated { +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !e.associated { return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue } - ep.rcvMu.Lock() + e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. - if ep.rcvList.Empty() { + if e.rcvList.Empty() { err := tcpip.ErrWouldBlock - if ep.rcvClosed { + if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return buffer.View{}, tcpip.ControlMessages{}, err } - packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + pkt := e.rcvList.Front() + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() - ep.rcvMu.Unlock() + e.rcvMu.Unlock() if addr != nil { - *addr = packet.senderAddr + *addr = pkt.senderAddr } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil + return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - ep.mu.RLock() + e.mu.RLock() - if ep.closed { - ep.mu.RUnlock() + if e.closed { + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, err } // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !ep.associated { + if !e.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { - ep.mu.RUnlock() + if !ip.IsValid(len(payloadBytes)) { + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -252,66 +263,66 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if !ep.connected { - ep.mu.RUnlock() + if !e.connected { + e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if ep.route.IsResolutionRequired() { - savedRoute := &ep.route + if e.route.IsResolutionRequired() { + savedRoute := &e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. - ep.mu.RUnlock() - ep.mu.Lock() + e.mu.RUnlock() + e.mu.Lock() // Make sure that the route didn't change during the // time we didn't hold the lock. - if !ep.connected || savedRoute != &ep.route { - ep.mu.Unlock() + if !e.connected || savedRoute != &e.route { + e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payloadBytes, savedRoute) - ep.mu.Unlock() + n, ch, err := e.finishWrite(payloadBytes, savedRoute) + e.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payloadBytes, &ep.route) - ep.mu.RUnlock() + n, ch, err := e.finishWrite(payloadBytes, &e.route) + e.mu.RUnlock() return n, ch, err } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC - if ep.bound && nic != 0 && nic != ep.boundNIC { - ep.mu.RUnlock() + if e.bound && nic != 0 && nic != e.BindNICID { + e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } // We don't support IPv6 yet, so this has to be an IPv4 address. if len(opts.To.Addr) != header.IPv4AddressSize { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } - // Find the route to the destination. If boundAddress is 0, + // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. - route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, err } - n, ch, err := ep.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, &route) route.Release() - ep.mu.RUnlock() + e.mu.RUnlock() return n, ch, err } // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -324,16 +335,21 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch ep.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - if !ep.associated { - if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + if !e.associated { + if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -345,7 +361,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -355,11 +371,11 @@ func (*endpoint) Disconnect() *tcpip.Error { } // Connect implements tcpip.Endpoint.Connect. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed { + if e.closed { return tcpip.ErrInvalidEndpointState } @@ -369,15 +385,15 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } nic := addr.NIC - if ep.bound { - if ep.boundNIC == 0 { + if e.bound { + if e.BindNICID == 0 { // If we're bound, but not to a specific NIC, the NIC // in addr will be used. Nothing to do here. } else if addr.NIC == 0 { // If we're bound to a specific NIC, but addr doesn't // specify a NIC, use the bound NIC. - nic = ep.boundNIC - } else if addr.NIC != ep.boundNIC { + nic = e.BindNICID + } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. return tcpip.ErrInvalidEndpointState @@ -385,53 +401,53 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the destination. - route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) if err != nil { return err } defer route.Release() - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = nic + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = nic } // Save the route we've connected via. - ep.route = route.Clone() - ep.connected = true + e.route = route.Clone() + e.connected = true return nil } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if !ep.connected { + if !e.connected { return tcpip.ErrNotConnected } return nil } // Listen implements tcpip.Endpoint.Listen. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() // Callers must provide an IPv4 address or no network address (for // binding to a NIC, but not an address). @@ -440,94 +456,100 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = addr.NIC + e.BindNICID = addr.NIC } - ep.boundAddr = addr.Addr - ep.bound = true + e.BindAddr = addr.Addr + e.bound = true return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } // Readiness implements tcpip.Endpoint.Readiness. -func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // The endpoint is always writable. result := waiter.EventOut & mask // Determine whether the endpoint is readable. if (mask & waiter.EventIn) != 0 { - ep.rcvMu.Lock() - if !ep.rcvList.Empty() || ep.rcvClosed { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { result |= waiter.EventIn } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() } return result } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 - ep.rcvMu.Lock() - if !ep.rcvList.Empty() { - p := ep.rcvList.Front() + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() v = p.data.Size() } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - ep.mu.Lock() - *o = tcpip.SendBufferSizeOption(ep.sndBufSize) - ep.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -538,63 +560,89 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { - ep.rcvMu.Lock() +func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { - ep.stack.Stats().DroppedPackets.Increment() - ep.rcvMu.Unlock() + if e.rcvClosed { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } - if ep.bound { + if e.bound { // If bound to a NIC, only accept data for that NIC. - if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { - ep.rcvMu.Unlock() + if e.BindNICID != 0 && e.BindNICID != route.NICID() { + e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + e.rcvMu.Unlock() return } } // If connected, only accept packets from the remote address we // connected to. - if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.connected && e.route.RemoteAddress != route.RemoteAddress { + e.rcvMu.Unlock() return } - wasEmpty := ep.rcvBufSize == 0 + wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - packet := &packet{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) - packet.timestampNS = ep.stack.NowNanoseconds() + networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) + combinedVV := networkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + packet.timestampNS = e.stack.NowNanoseconds() - ep.rcvList.PushBack(packet) - ep.rcvBufSize += packet.data.Size() - - ep.rcvMu.Unlock() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.rcvMu.Unlock() + e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { - ep.waiterQueue.Notify(waiter.EventIn) + e.waiterQueue.Notify(waiter.EventIn) } } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 168953dec..33bfb56cd 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -20,15 +20,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// saveData saves packet.data field. -func (p *packet) saveData() buffer.VectorisedView { +// saveData saves rawPacket.data field. +func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } -// loadData loads packet.data field. -func (p *packet) loadData(data buffer.VectorisedView) { +// loadData loads rawPacket.data field. +func (p *rawPacket) loadData(data buffer.VectorisedView) { // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point @@ -73,7 +73,7 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if ep.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) if err != nil { panic(err) } @@ -81,12 +81,14 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(err) + if ep.associated { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + panic(err) + } } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index 783c21e6b..f30aa2a4a 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -17,16 +17,19 @@ package raw import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/packet" "gvisor.dev/gvisor/pkg/waiter" ) -type factory struct{} +// EndpointFactory implements stack.RawFactory. +type EndpointFactory struct{} -// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } -func init() { - stack.RegisterUnassociatedFactory(factory{}) +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 1ee1a53f8..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "tcp_segment_list", @@ -27,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", @@ -43,12 +45,15 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", @@ -58,17 +63,9 @@ go_library( ], ) -filegroup( - name = "autogen", - srcs = [ - "tcp_segment_list.go", - ], - visibility = ["//:sandbox"], -) - go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0802e984e..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -229,7 +230,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n := newEndpoint(l.stack, netProto, nil) n.v6only = l.v6only - n.id = s.id + n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} @@ -241,8 +242,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In case of Forwarder listenEP will be nil and hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() return nil, err } @@ -268,8 +276,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -288,9 +296,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -298,7 +305,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.state = StateEstablished + ep.isConnectNotified = true ep.mu.Unlock() // Update the receive window scaling. We can't do it before the @@ -311,14 +318,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.id] = n + l.pendingEndpoints[n.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.id) + delete(l.pendingEndpoints, n.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -350,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. @@ -360,12 +375,19 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -399,8 +421,19 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - switch s.flags { - case header.TCPFlagSyn: + if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + return + } + + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { // Only handle the syn if the following conditions hold @@ -414,6 +447,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } decSynRcvdCount() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } else { @@ -421,6 +455,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // is full then drop the syn. if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -431,19 +466,18 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } - sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -451,11 +485,20 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // complete the connection at the time of retransmit if // the backlog has space. e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } if !synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. @@ -505,6 +548,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } @@ -515,7 +559,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished + n.isConnectNotified = true // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -526,6 +574,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } @@ -536,7 +589,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 21038a65a..cdd69f360 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -78,9 +80,6 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 - // amss is the maximum segment size advertised by us to the peer. - amss uint16 - // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -142,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -194,6 +218,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK + // was acceptable then signal the user "error: connection reset", drop + // the segment, enter CLOSED state, delete TCB, and return." + h.ep.mu.Lock() + h.ep.workerCleanup = true + h.ep.mu.Unlock() + // Although the RFC above calls out ECONNRESET, Linux actually returns + // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused } return nil @@ -228,6 +260,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -238,6 +275,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() h.ep.state = StateSynRecv + ttl := h.ep.ttl h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), @@ -251,8 +289,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { SACKPermitted: rcvSynOpts.SACKPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - + if ttl == 0 { + ttl = s.route.DefaultTTL() + } + h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -272,6 +312,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } + // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a + // sequence number outside of the window causes an ACK with the proper seq + // number and "After sending the acknowledgment, drop the unacceptable + // segment and return." + if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + return nil + } + if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole @@ -296,7 +345,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -316,6 +365,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -383,6 +436,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { switch index { case wakerForResolution: if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + if err == tcpip.ErrNoLinkAddress { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + } else if err != nil { + h.ep.stats.SendErrors.NoRoute.Increment() + } // Either success (err == nil) or failure. return err } @@ -437,7 +495,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -460,7 +518,8 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -469,7 +528,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -579,26 +638,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) + // We ignore SYN send errors and let the callers re-attempt send. + if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + e.stats.SendErrors.SynSendToNetworkFailed.Increment() + } putOptions(options) - return err + return nil } -// sendTCP sends a TCP segment with the provided options via the provided -// network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { + e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() + return err } + e.stats.SegmentsSent.Increment() + return nil +} +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { + optLen := len(opts) + hdr := &pkt.Header + packetSize := pkt.DataSize + off := pkt.DataOffset // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + pkt.TransportHeader = buffer.View(tcp) tcp.Encode(&header.TCPFields{ SrcPort: id.LocalPort, DstPort: id.RemotePort, @@ -610,7 +676,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise }) copy(tcp[header.TCPMinimumSize:], opts) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { @@ -620,16 +686,87 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVV(data, xsum) + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } +} + +func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + mss := int(gso.MSS) + n := (data.Size() + mss - 1) / mss + + // Allocate one big slice for all the headers. + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + buf := make([]byte, n*hdrSize) + pkts := make([]tcpip.PacketBuffer, n) + for i := range pkts { + pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) + } + + size := data.Size() + off := 0 + for i := 0; i < n; i++ { + packetSize := mss + if packetSize > size { + packetSize = size + } + size -= packetSize + pkts[i].DataOffset = off + pkts[i].DataSize = packetSize + pkts[i].Data = data + buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso) + off += packetSize + seq = seq.Add(seqnum.Size(packetSize)) + } + if ttl == 0 { + ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + if err != nil { + r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) + } + r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent)) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + } + + pkt := tcpip.PacketBuffer{ + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + DataOffset: 0, + DataSize: data.Size(), + Data: data, + } + buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso) + + if ttl == 0 { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil { + r.Stats().TCP.SegmentSendErrors.Increment() + return err + } r.Stats().TCP.SegmentsSent.Increment() if (flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } - - return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) + return nil } // makeOptions makes an options slice. @@ -678,7 +815,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } @@ -727,10 +864,26 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. + if e.state == StateEstablished || e.state == StateCloseWait { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } e.state = StateError - e.hardError = err - if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + e.HardError = err + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { + resetSeqNum = e.snd.sndNxt + } + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) } } @@ -744,6 +897,99 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + if e.state == StateClose { + return + } + e.cleanupLocked() + e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil { + replyWithReset(s) + s.decRef() + return + } + ep.(*endpoint).enqueueSegment(s) +} + +func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + e.mu.Lock() + switch e.state { + // In case of a RST in CLOSE-WAIT linux moves + // the socket to closed state with an error set + // to indicate EPIPE. + // + // Technically this seems to be at odds w/ RFC. + // As per https://tools.ietf.org/html/rfc793#section-2.7 + // page 69 the behavior for a segment arriving + // w/ RST bit set in CLOSE-WAIT is inlined below. + // + // ESTABLISHED + // FIN-WAIT-1 + // FIN-WAIT-2 + // CLOSE-WAIT + + // If the RST bit is set then, any outstanding RECEIVEs and + // SEND should receive "reset" responses. All segment queues + // should be flushed. Users should also receive an unsolicited + // general "connection reset" signal. Enter the CLOSED state, + // delete the TCB, and return. + case StateCloseWait: + e.transitionToStateCloseLocked() + e.HardError = tcpip.ErrAborted + e.mu.Unlock() + return false, nil + default: + e.mu.Unlock() + return false, tcpip.ErrConnectionReset + } + } + return true, nil +} + // handleSegments pulls segments from the queue and processes them. It returns // no error if the protocol loop should continue, an error otherwise. func (e *endpoint) handleSegments() *tcpip.Error { @@ -761,14 +1007,34 @@ func (e *endpoint) handleSegments() *tcpip.Error { } if s.flagIsSet(header.TCPFlagRst) { - if e.rcv.acceptable(s.sequenceNumber, 0) { - // RFC 793, page 37 states that "in all states - // except SYN-SENT, all reset (RST) segments are - // validated by checking their SEQ-fields." So - // we only process it if it's acceptable. - s.decRef() - return tcpip.ErrConnectionReset + if ok, err := e.handleReset(s); !ok { + return err } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -777,7 +1043,33 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return nil + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -803,14 +1095,27 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -827,7 +1132,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -835,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -842,6 +1147,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -876,7 +1182,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -898,28 +1203,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.mu.Lock() e.state = StateError - e.hardError = err + e.HardError = err // Lock released below. epilogue() return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -927,7 +1217,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -958,13 +1247,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.transitionToStateCloseLocked() + e.mu.Unlock() + return nil }, }, { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1001,17 +1297,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1027,12 +1324,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != StateError { + if e.state != StateClose && e.state != StateError { + // Only block the worker if the endpoint + // is not in closed state or error state. close(e.drainDone) <-e.undrain } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1059,15 +1364,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1083,15 +1389,168 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.state = StateClose + e.stack.Stats().TCP.CurrentEstablished.Decrement() + e.transitionToStateCloseLocked() } + // Lock released below. epilogue() + // epilogue removes the endpoint from the transport-demuxer and + // unlocks e.mu. Now that no new segments can get enqueued to this + // endpoint, try to re-match the segment to a different endpoint + // as the current endpoint is closed. + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } + + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index c54610a87..dfaa4a559 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { } } -func testV4Connect(t *testing.T, c *context.Context) { +func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv4(t, b, synCheckers...) tcp := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv4(t, c.GetPacket(), ackCheckers...) // Wait for connection to be established. select { @@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { testV4Connect(t, c) } -func testV6Connect(t *testing.T, c *context.Context) { +func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt to IPv6 address. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetV6Packet() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv6(t, b, synCheckers...) tcp := header.TCP(header.IPv6(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv6(t, c.GetV6Packet(), ackCheckers...) // Wait for connection to be established. select { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..dd8b47cbe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "fmt" "math" "strings" @@ -26,8 +27,10 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -119,6 +122,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -170,6 +178,101 @@ type rcvBufAutoTuneParams struct { disabled bool } +// ReceiveErrors collect segment receive errors within transport layer. +type ReceiveErrors struct { + tcpip.ReceiveErrors + + // SegmentQueueDropped is the number of segments dropped due to + // a full segment queue. + SegmentQueueDropped tcpip.StatCounter + + // ChecksumErrors is the number of segments dropped due to bad checksums. + ChecksumErrors tcpip.StatCounter + + // ListenOverflowSynDrop is the number of times the listen queue overflowed + // and a SYN was dropped. + ListenOverflowSynDrop tcpip.StatCounter + + // ListenOverflowAckDrop is the number of times the final ACK + // in the handshake was dropped due to overflow. + ListenOverflowAckDrop tcpip.StatCounter + + // ZeroRcvWindowState is the number of times we advertised + // a zero receive window when rcvList is full. + ZeroRcvWindowState tcpip.StatCounter +} + +// SendErrors collect segment send errors within the transport layer. +type SendErrors struct { + tcpip.SendErrors + + // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent + // to the network endpoint. + SegmentSendToNetworkFailed tcpip.StatCounter + + // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent + // to the network endpoint. + SynSendToNetworkFailed tcpip.StatCounter + + // Retransmits is the number of TCP segments retransmitted. + Retransmits tcpip.StatCounter + + // FastRetransmit is the number of segments retransmitted in fast + // recovery. + FastRetransmit tcpip.StatCounter + + // Timeouts is the number of times the RTO expired. + Timeouts tcpip.StatCounter +} + +// Stats holds statistics about the endpoint. +type Stats struct { + // SegmentsReceived is the number of TCP segments received that + // the transport layer successfully parsed. + SegmentsReceived tcpip.StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent tcpip.StatCounter + + // FailedConnectionAttempts is the number of times we saw Connect and + // Accept errors. + FailedConnectionAttempts tcpip.StatCounter + + // ReceiveErrors collects segment receive errors within the + // transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects segment read errors from an endpoint read call. + ReadErrors tcpip.ReadErrors + + // SendErrors collects segment send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects segment write errors from an endpoint write call. + WriteErrors tcpip.WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*Stats) IsEndpointStats() {} + +// EndpointInfo holds useful information about a transport endpoint which +// can be queried by monitoring tools. +// +// +stateify savable +type EndpointInfo struct { + stack.TransportEndpointInfo + + // HardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. HardError is protected by endpoint mu. + HardError *tcpip.Error `state:".(string)"` +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*EndpointInfo) IsEndpointInfo() {} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -178,6 +281,8 @@ type rcvBufAutoTuneParams struct { // // +stateify savable type endpoint struct { + EndpointInfo + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -186,9 +291,9 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. - stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber + stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -218,20 +323,30 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` + ttl uint8 v6only bool isConnectNotified bool // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -240,11 +355,6 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` - // hardError is meaningful only when state is stateError, it stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. hardError is protected by mu. - hardError *tcpip.Error `state:".(string)"` - // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -280,6 +390,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -315,7 +428,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -362,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibily aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -393,13 +512,49 @@ type endpoint struct { probe stack.TCPProbeFunc `state:"nosave"` // The following are only used to assist the restore run to re-connect. - bindAddress tcpip.Address connectingAddress tcpip.Address // amss is the advertised MSS to the peer by this endpoint. amss uint16 + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + gso *stack.GSO + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS } // StopWork halts packet processing. Only to be used in tests. @@ -427,10 +582,15 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + EndpointInfo: EndpointInfo{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + }, waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, @@ -443,29 +603,40 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } var rs ReceiveBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } var cs tcpip.CongestionControlOption - if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } var mrb tcpip.ModerateReceiveBufferOption - if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - if p := stack.GetTCPProbe(); p != nil { + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptInt(tcpip.DelayOption, 1) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -552,6 +723,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -564,14 +742,18 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -597,9 +779,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -625,16 +805,19 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -645,7 +828,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -731,11 +916,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.hardError + he := e.HardError e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } + e.stats.ReadErrors.InvalidEndpointState.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -744,6 +930,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RUnlock() + if err == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } return v, tcpip.ControlMessages{}, err } @@ -787,7 +976,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { if !e.state.connected() { switch e.state { case StateError: - return 0, e.hardError + return 0, e.HardError default: return 0, tcpip.ErrClosedForSend } @@ -806,7 +995,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -818,50 +1007,57 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) - s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +1071,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. @@ -889,8 +1086,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.state; !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardError + return 0, tcpip.ControlMessages{}, e.HardError } + e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -899,6 +1097,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.state.connected() { + e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -946,62 +1145,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil - - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - +// SetSockOptInt sets a socket option. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1065,9 +1211,87 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Unlock() return nil + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicID, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicID + return nil + } + } + return tcpip.ErrUnknownDevice + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = uint16(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -1082,6 +1306,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 return nil + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + return nil + case tcpip.KeepaliveEnabledOption: e.keepalive.Lock() e.keepalive.enabled = v != 0 @@ -1110,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1150,6 +1386,45 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. return tcpip.ErrNoSuchFile + + case tcpip.IPv4TOSOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1176,8 +1451,29 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -1198,25 +1494,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.sndBufMu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) - e.rcvListMu.Unlock() - return nil - - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1246,6 +1523,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1255,7 +1542,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -1269,6 +1556,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.mu.RLock() @@ -1311,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 @@ -1333,13 +1632,31 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { // Fail if using a v4 mapped address on a v6only endpoint. if e.v6only { @@ -1355,7 +1672,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -1369,7 +1686,12 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return e.connect(addr, true, true) + err := e.connect(addr, true, true) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err } // connect connects the endpoint to its peer. In the normal non-S/R case, the @@ -1378,14 +1700,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() connectingAddr := addr.Addr @@ -1405,7 +1722,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1414,11 +1731,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1430,29 +1747,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnecting case StateError: - return e.hardError + return e.HardError default: return tcpip.ErrInvalidEndpointState } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.id + origID := e.ID netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.id.LocalAddress = r.LocalAddress - e.id.RemoteAddress = r.RemoteAddress - e.id.RemotePort = addr.Port + e.ID.LocalAddress = r.LocalAddress + e.ID.RemoteAddress = r.RemoteAddress + e.ID.RemotePort = addr.Port - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) if err != nil { return err } @@ -1461,20 +1778,38 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.id.LocalAddress == e.id.RemoteAddress - if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - if sameAddr && p == e.id.RemotePort { + sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + + // Calculate a port offset based on the destination IP/port and + // src IP to ensure that for a given tuple (srcIP, destIP, + // destPort) the offset used as a starting point is the same to + // ensure that we can cycle through the port space effectively. + h := jenkins.Sum32(e.stack.Seed()) + h.Write([]byte(e.ID.LocalAddress)) + h.Write([]byte(e.ID.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) + h.Write(portBuf) + portOffset := h.Sum32() + + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.ID.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { return false, nil } - id := e.id + id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: - e.id = id + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.boundBindToDevice = e.bindToDevice return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1490,14 +1825,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1509,7 +1844,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Lock() for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.id + s.id = e.ID s.route = r.Clone() e.sndWaker.Assert() } @@ -1517,6 +1852,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -1537,9 +1873,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1554,6 +1889,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1569,17 +1905,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.id, nil) + s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1587,24 +1920,37 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { +func (e *endpoint) Listen(backlog int) *tcpip.Error { + err := e.listen(backlog) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err +} + +func (e *endpoint) listen(backlog int) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -1630,11 +1976,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { + e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { return err } @@ -1678,12 +2025,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -1698,7 +2040,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { return tcpip.ErrAlreadyBound } - e.bindAddress = addr.Addr + e.BindAddr = addr.Addr netProto, err := e.checkV4Mapped(&addr) if err != nil { return err @@ -1715,26 +2057,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = flags e.isPortReserved = true e.effectiveNetProtos = netProtos - e.id.LocalPort = port + e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil - e.id.LocalPort = 0 - e.id.LocalAddress = "" + e.ID.LocalPort = 0 + e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }() + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -1745,7 +2094,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } e.boundNICID = nic - e.id.LocalAddress = addr.Addr + e.ID.LocalAddress = addr.Addr } // Mark endpoint as bound. @@ -1760,8 +2109,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -1776,19 +2125,20 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, NIC: e.boundNICID, }, nil } // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() return } @@ -1796,27 +2146,34 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.csumValid { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() return } e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + e.stats.SegmentsReceived.Increment() if (s.flags & header.TCPFlagRst) != 0 { e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() } else { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.SegmentQueueDropped.Increment() s.decRef() } } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -1860,6 +2217,7 @@ func (e *endpoint) readyToRead(s *segment) { // that a subsequent read of the segment will correctly trigger // a non-zero notification. if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() e.zeroWindow = true } e.rcvList.PushBack(s) @@ -2012,7 +2370,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Copy EndpointID. e.mu.Lock() - s.ID = stack.TCPEndpointID(e.id) + s.ID = stack.TCPEndpointID(e.ID) e.mu.Unlock() // Copy endpoint rcv state. @@ -2105,11 +2463,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { return s } -func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityGSO == 0 { - return - } - +func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} switch e.route.NetProto { case header.IPv4ProtocolNumber: @@ -2119,7 +2473,7 @@ func (e *endpoint) initGSO() { gso.Type = stack.GSOTCPv6 gso.L3HdrLen = header.IPv6MinimumSize default: - panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } gso.NeedsCsum = true gso.CsumOffset = header.TCPChecksumOffset @@ -2127,6 +2481,18 @@ func (e *endpoint) initGSO() { e.gso = gso } +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { @@ -2135,6 +2501,37 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.EndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 831389ec7..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() { case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress + if len(e.BindAddr) == 0 { + e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -202,21 +209,31 @@ func (e *endpoint) Resume(s *stack.Stack) { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress + e.connectingAddress = e.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -236,7 +253,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } @@ -288,21 +309,21 @@ func (e *endpoint) loadLastError(s string) { } // saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { +func (e *EndpointInfo) saveHardError() string { + if e.HardError == nil { return "" } - return e.hardError.String() + return e.HardError.String() } // loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { +func (e *EndpointInfo) loadHardError(s string) { if s == "" { return } - e.hardError = loadError(s) + e.HardError = loadError(s) } var messageToError map[string]*tcpip.Error diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a13b2022..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -14,7 +14,7 @@ // Package tcp contains the implementation of the TCP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the +// activated on the stack by passing tcp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing tcp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -34,9 +35,6 @@ import ( ) const ( - // ProtocolName is the string representation of the tcp protocol name. - ProtocolName = "tcp" - // ProtocolNumber is the tcp protocol number. ProtocolNumber = header.TCPProtocolNumber @@ -57,12 +55,23 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP // protocol. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. +type DelayEnabled bool + // SendBufferSizeOption allows the default, min and max send buffer sizes for // TCP endpoints to be queried or configured. type SendBufferSizeOption struct { @@ -87,11 +96,14 @@ const ( type protocol struct { mu sync.Mutex sackEnabled bool + delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -100,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, waiterQueue), nil } @@ -129,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { @@ -150,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. @@ -168,6 +193,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue @@ -205,6 +236,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -219,6 +268,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *DelayEnabled: + p.mu.Lock() + *v = DelayEnabled(p.delayEnabled) + p.mu.Unlock() + return nil + case *SendBufferSizeOption: p.mu.Lock() *v = p.sendBufferSize @@ -249,18 +304,31 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, - congestionControl: ccReno, - availableCongestionControl: []string{ccReno, ccCubic}, - } - }) +// NewProtocol returns a TCP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -49,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -204,15 +209,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } r.ep.mu.Unlock() } @@ -253,25 +263,110 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } + // Store the time of the last ack. + r.lastRcvdAckTime = time.Now() + // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { @@ -288,7 +383,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +410,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } @@ -99,8 +100,14 @@ func (s *segment) clone() *segment { return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 735edfe55..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. - minRTO = 200 * time.Millisecond + // MinRTO is the minimum allowed value for the retransmit timeout. + MinRTO = 200 * time.Millisecond + + // MaxRTO is the maximum allowed value for the retransmit timeout. + MaxRTO = 120 * time.Second // InitialCwnd is the initial congestion window. InitialCwnd = 10 @@ -134,6 +137,10 @@ type sender struct { // rttMeasureTime is the time when the rttMeasureSeqNum was sent. rttMeasureTime time.Time `state:".(unixTime)"` + // firstRetransmittedSegXmitTime is the original transmit time of + // the first segment that was retransmitted due to RTO expiration. + firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + closed bool writeNext *segment writeList segmentList @@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) { s.rto = s.rtt.srtt + 4*s.rtt.rttvar s.rtt.Unlock() - if s.rto < minRTO { - s.rto = minRTO + if s.rto < MinRTO { + s.rto = MinRTO } } @@ -417,6 +424,7 @@ func (s *sender) resendSegment() { s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() + s.ep.stats.SendErrors.FastRetransmit.Increment() // Run SetPipe() as per RFC 6675 section 5 Step 4.4 s.SetPipe() @@ -435,9 +443,32 @@ func (s *sender) retransmitTimerExpired() bool { } s.ep.stack.Stats().TCP.Timeouts.Increment() + s.ep.stats.SendErrors.Timeouts.Increment() + + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + s.ep.mu.RLock() + uto := s.ep.userTimeout + s.ep.mu.RUnlock() + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. + s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } - // Give up if we've waited more than a minute since the last resend. - if s.rto >= 60*time.Second { + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := MaxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + if remaining <= 0 || s.rto >= MaxRTO { return false } @@ -445,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool { // below. s.rto *= 2 + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. // // Retransmit timeouts: @@ -1166,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) { // RFC 6298 Rule 5.3 if s.sndUna == s.sndNxt { s.outstanding = 0 + // Reset firstRetransmittedSegXmitTime to the zero value. + s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() } } @@ -1188,6 +1226,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { func (s *sender) sendSegment(seg *segment) *tcpip.Error { if !seg.xmitTime.IsZero() { s.ep.stack.Stats().TCP.Retransmits.Increment() + s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 12eff8afc..8b20c3455 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) } + +// saveFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { + return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} +} + +// loadFirstRetransmittedSegXmitTime is invoked by stateify. +func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { + s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 272bbcdbd..782d7b42c 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { enableCUBIC(t, c) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 4e7f1a740..afea124ec 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) { t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + } + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) // Acknowledge all pending data to recover point. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 32bb45224..e8fe4dab5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) { if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) } + + // Call Connect again to retreive the handshake failure status + // and stats updates. + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + } + + if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestConnectIncrementActiveConnection(t *testing.T) { @@ -84,7 +98,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -97,9 +111,12 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) } } @@ -122,6 +139,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + } } func TestTCPSegmentsSentIncrement(t *testing.T) { @@ -131,11 +151,14 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { + t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + } } func TestTCPResetsSentIncrement(t *testing.T) { @@ -197,17 +220,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -247,7 +271,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -255,6 +279,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -276,6 +307,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -287,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) } func TestTCPResetsReceivedIncrement(t *testing.T) { @@ -299,7 +335,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -323,7 +359,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -344,14 +380,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -367,7 +403,14 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -387,12 +430,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -404,20 +459,133 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), ), ) break } } +// TestClosingWithEnqueuedSegments tests handling of still enqueued segments +// when the endpoint transitions to StateClose. The in-flight segments would be +// re-enqueued to a any listening endpoint. +func TestClosingWithEnqueuedSegments(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + ep := c.EP + c.EP = nil + + if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Send a FIN for ESTABLISHED --> CLOSED-WAIT + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Get the ACK for the FIN we sent. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Close the application endpoint for CLOSE_WAIT --> LAST_ACK + ep.Close() + + // Get the FIN + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Pause the endpoint`s protocolMainLoop. + ep.(interface{ StopWork() }).StopWork() + + // Enqueue last ACK followed by an ACK matching the endpoint + // + // Send Last ACK for LAST_ACK --> CLOSED + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Send a packet with ACK set, this would generate RST when + // not using SYN cookies as in this test. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 792, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Unpause endpoint`s protocolMainLoop. + ep.(interface{ ResumeWork() }).ResumeWork() + + // Wait for the protocolMainLoop to resume and update state. + time.Sleep(10 * time.Millisecond) + + // Expect the endpoint to be closed. + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } + + // Check if the endpoint was moved to CLOSED and netstack a reset in + // response to the ACK packet that we sent after last-ACK. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst), + ), + ) +} + func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -465,11 +633,583 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +func TestSendRstOnListenerRxSynAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxSynAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv4. +func TestTCPAckBeforeAcceptV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, +// peers can send data and expect a response within a reasonable ammount of time +// without calling Accept on the listening endpoint first. +// +// This test uses IPv6. +func TestTCPAckBeforeAcceptV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) + + // Send data before accepting the connection. + c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint not processing +// any receive when it is on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + // Wait for the endpoint state to be propagated. + time.Sleep(10 * time.Millisecond) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + c.CheckNoPacket("Packet received when listening socket was shutdown") +} + +func TestTOSV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + var v tcpip.IPv4TOSOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testV4Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestTrafficClassV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + } + + var v tcpip.IPv6TrafficClassOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + // Test the connection request. + testV6Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetV6Packet() + checker.IPv6(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + +func TestRstOnSynSent(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) + + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + // Ensure that we've reached SynSent state + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Send a packet with a proper ACK and a RST flag to cause the socket + // to Error and close out + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -557,8 +1297,7 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -631,7 +1370,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -700,7 +1439,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -761,8 +1500,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -785,7 +1523,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -798,14 +1536,17 @@ func TestShutdownRead(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + } } func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -855,6 +1596,11 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("got data = %v, want = %v", v, data) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + } + // Check that we get an ACK for the newly non-zero window. checker.IPv4(t, c.GetPacket(), checker.TCP( @@ -872,11 +1618,9 @@ func TestNoWindowShrinking(t *testing.T) { defer c.Cleanup() // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { t.Fatalf("SetSockOpt failed: %v", err) } @@ -976,7 +1720,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1017,7 +1761,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, nil) + c.CreateConnected(789, 0, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1075,8 +1819,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1110,8 +1853,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1151,7 +1893,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1224,7 +1966,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1293,8 +2035,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the window size such that a window scale of 4 will be used. const wnd = 65535 * 10 const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1399,7 +2140,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Prevent the endpoint from processing packets. test.stop(c.EP) @@ -1449,9 +2190,9 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -1497,9 +2238,9 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -1532,7 +2273,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOpt(tcpip.DelayOption(0)) + c.EP.SetSockOptInt(tcpip.DelayOption, 0) // Check that data is received. second := c.GetPacket() @@ -1569,7 +2310,7 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, } @@ -1579,7 +2320,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -1695,16 +2436,44 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } +func TestSetTTL(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() + + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + + checker.IPv4(t, b, checker.TTL(wantTTL)) + }) + } +} + func TestActiveSendMSSLessThanMTU(t *testing.T) { const maxPayload = 100 c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -1727,7 +2496,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1871,7 +2640,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1973,7 +2742,7 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2004,13 +2773,27 @@ loop: t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") + } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2035,7 +2818,7 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2078,7 +2861,7 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2132,7 +2915,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := buffer.NewView(10) @@ -2203,7 +2986,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -2291,7 +3074,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) @@ -2377,7 +3160,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -2509,7 +3292,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) @@ -2559,7 +3342,7 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 @@ -2575,12 +3358,23 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + } + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 vv := c.BuildSegment(nil, &context.Headers{ @@ -2599,12 +3393,15 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } } func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ @@ -2625,6 +3422,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { if got := stats.TCP.ChecksumErrors.Value(); got != want { t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) + } } func TestReceivedSegmentQueuing(t *testing.T) { @@ -2635,7 +3435,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send 200 segments. data := []byte{1, 2, 3} @@ -2681,19 +3481,26 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2731,10 +3538,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -2751,7 +3557,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -2762,7 +3568,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -2772,11 +3578,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -2856,8 +3662,8 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2869,8 +3675,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2880,7 +3686,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } func TestDefaultBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2926,7 +3735,10 @@ func TestDefaultBufferSizes(t *testing.T) { } func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2945,37 +3757,96 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New([]string{ - ipv4.ProtocolName, - ipv6.ProtocolName, - }, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + ipv6.NewProtocol(), + }, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) id := loopback.New() if testing.Verbose() { @@ -3231,7 +4102,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -3308,7 +4179,7 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3482,10 +4353,11 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) @@ -3575,19 +4447,43 @@ func TestKeepalive(t *testing.T) { ) } + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + // The connection should be terminated after 5 unacked keepalives. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.SeqNum(uint32(next)), - checker.AckNum(uint32(790)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), ), ) + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + } + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } + + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -3601,7 +4497,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki RcvWnd: 30000, }) - // Receive the SYN-ACK reply.w + // Receive the SYN-ACK reply. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss = seqnum.Value(tcp.SequenceNumber()) @@ -3634,6 +4530,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki return irs, iss } +func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + // Send a SYN request. + irs = seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetV6Packet() + tcp := header.TCP(header.IPv6(b).Payload()) + iss = seqnum.Value(tcp.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(srcPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + + if synCookieInUse { + // When cookies are in use window scaling is disabled. + tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ + WS: -1, + MSS: c.MSSWithoutOptionsV6(), + })) + } + + checker.IPv6(t, b, checker.TCP(tcpCheckers...)) + + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: srcPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + return irs, iss +} + // TestListenBacklogFull tests that netstack does not complete handshakes if the // listen backlog for the endpoint is full. func TestListenBacklogFull(t *testing.T) { @@ -3736,6 +4676,210 @@ func TestListenBacklogFull(t *testing.T) { } } +// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a +// non unicast IPv4 address are not accepted. +func TestListenNoAcceptNonUnicastV4(t *testing.T) { + multicastAddr := tcpip.Address("\xe0\x00\x01\x02") + otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv4Any, + context.StackAddr, + }, + { + "SourceBroadcast", + header.IPv4Broadcast, + context.StackAddr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackAddr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackAddr, + }, + { + "DestUnspecified", + context.TestAddr, + header.IPv4Any, + }, + { + "DestBroadcast", + context.TestAddr, + header.IPv4Broadcast, + }, + { + "DestOurMulticast", + context.TestAddr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestAddr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestAddr, context.StackAddr) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + +// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a +// non unicast IPv6 address are not accepted. +func TestListenNoAcceptNonUnicastV6(t *testing.T) { + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv6Any, + context.StackV6Addr, + }, + { + "SourceAllNodes", + header.IPv6AllNodesMulticastAddress, + context.StackV6Addr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackV6Addr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackV6Addr, + }, + { + "DestUnspecified", + context.TestV6Addr, + header.IPv6Any, + }, + { + "DestAllNodes", + context.TestV6Addr, + header.IPv6AllNodesMulticastAddress, + }, + { + "DestOurMulticast", + context.TestV6Addr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestV6Addr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestV6Addr, context.StackV6Addr) + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + func TestListenSynRcvdQueueFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3767,7 +4911,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { SrcPort: context.TestPort, DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: irs, RcvWnd: 30000, }) @@ -3878,7 +5022,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -3918,6 +5063,125 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSynRcvdBadSeqNumber(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now send a packet with an out-of-window sequence number + largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: largeSeqnum, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Should receive an ACK with the expected SEQ number + b = c.GetPacket() + tcpCheckers = []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.AckNum(uint32(irs) + 1), + checker.SeqNum(uint32(iss + 1)), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now that the socket replied appropriately with the ACK, + // complete the connection to test that the large SEQ num + // did not change the state from SYN-RCVD. + + // Send ACK to move to ESTABLISHED state. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + newEP, _, err := c.EP.Accept() + + if err != nil && err != tcpip.ErrWouldBlock { + t.Fatalf("Accept failed: %s", err) + } + + if err == tcpip.ErrWouldBlock { + // Try to accept the connections in the backlog. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + pkt := c.GetPacket() + tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + if string(tcpHdr.Payload()) != data { + t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) + } +} + func TestPassiveConnectionAttemptIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4012,6 +5276,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + } we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -4050,6 +5317,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + // Expect InvalidEndpointState errors on a read at this point. + if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { + t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + } + if err := ep.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } @@ -4081,6 +5356,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -4368,3 +5646,918 @@ func TestReceiveBufferAutoTuning(t *testing.T) { payloadSize *= 2 } } + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, 0) // Delay is disabled by default. + + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption int + }{ + {delayEnabled: false, wantDelayOption: 0}, + {delayEnabled: true, wantDelayOption: 1}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) + } + gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + } +} + +func TestTCPCloseWithData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + + c.EP.Close() + // Check the FIN. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.AckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Check the RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + userTimeout := 50 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for a little over the minimum retransmit timeout of 200ms for + // the retransmitTimer to fire and close the connection. + time.Sleep(tcp.MinRTO + 10*time.Millisecond) + + // No packet should be received as the connection should be silently + // closed due to timeout. + c.CheckNoPacket("unexpected packet received after userTimeout has expired") + + next += uint32(len(view)) + + // The connection should be terminated after userTimeout has expired. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(next)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} + +func TestKeepaliveWithUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + const keepAliveInterval = 10 * time.Millisecond + c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) + c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) + c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + + // Set userTimeout to be the duration for 3 keepalive probes. + userTimeout := 30 * time.Millisecond + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Check that the connection is still alive. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Now receive 2 keepalives, but don't ACK them. The connection should + // be reset when the 3rd one should be sent due to userTimeout being + // 30ms and each keepalive probe should be sent 10ms apart as set above after + // the connection has been idle for 10ms. + for i := 0; i < 2; i++ { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(uint32(790)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Sleep for a litte over the KeepAlive interval to make sure + // the timer has time to fire after the last ACK and close the + // close the socket. + time.Sleep(keepAliveInterval + 5*time.Millisecond) + + // The connection should be terminated after 30ms. + // Send an ACK to trigger a RST from the stack as the endpoint should + // be dead. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index 19b0d31c5..b33ec2087 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -8,7 +8,7 @@ go_library( srcs = ["context.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ - "//:sandbox", + "//visibility:public", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 18c707a57..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -137,7 +137,10 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Allow minimum send/receive buffer sizes to be 1 during tests. if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { @@ -150,11 +153,19 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") + ep := channel.New(1000, mtu, "") + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { + wep = sniffer.New(ep) + } + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) if testing.Verbose() { - id = sniffer.New(id) + wep2 = sniffer.New(channel.New(1000, mtu, "")) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -180,7 +191,7 @@ func New(t *testing.T, mtu uint32) *Context { return &Context{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), } } @@ -220,14 +231,15 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -248,14 +260,15 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -291,11 +304,19 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) +} + +// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, +// payload and source and destination IPv4 addresses. +func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -308,8 +329,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, + SrcAddr: src, + DstAddr: dst, }) ip.SetChecksum(^ip.CalculateChecksum()) @@ -326,7 +347,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -339,13 +360,26 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: s, + }) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegment(payload, h), + }) +} + +// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload +// & TCPheaders) in an IPv4 packet via the link layer endpoint using the +// provided source and destination IPv4 addresses. +func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegmentWithAddrs(payload, h, src, dst), + }) } // SendAck sends an ACK packet. @@ -451,14 +485,15 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -473,6 +508,13 @@ func (c *Context) GetV6Packet() []byte { // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of // the context. func (c *Context) SendV6Packet(payload []byte, h *Headers) { + c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) +} + +// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer +// endpoint of the context using the provided source and destination IPv6 +// addresses. +func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -483,8 +525,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { PayloadLength: uint16(header.TCPMinimumSize + len(payload)), NextHeader: uint8(tcp.ProtocolNumber), HopLimit: 65, - SrcAddr: TestV6Addr, - DstAddr: StackV6Addr, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. @@ -500,18 +542,20 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } @@ -584,12 +628,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -597,11 +637,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("NewEndpoint failed: %v", err) } - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + if epRcvBuf != -1 { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } @@ -1043,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) { func (c *Context) MSSWithoutOptions() uint16 { return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) } + +// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no +// options are in use for IPv6 packets. +func (c *Context) MSSWithoutOptionsV6() uint16 { + return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) +} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 4bec48c0f..43fcc27f0 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index ac2666f69..97e4d5825 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "udp_packet_list", @@ -33,6 +34,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", @@ -50,6 +52,7 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", @@ -57,11 +60,3 @@ go_test( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "udp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 66455ef46..1ac4705af 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,13 +15,13 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -32,9 +32,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } // EndpointState represents the state of a UDP endpoint. @@ -50,6 +47,22 @@ const ( StateClosed ) +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnected: + return "CONNECTING" + case StateClosed: + return "CLOSED" + default: + return "UNKNOWN" + } +} + // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -59,11 +72,13 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -77,20 +92,28 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int - id stack.TransportEndpointID state EndpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID route stack.Route `state:"manual"` dstPort uint16 v6only bool + ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool reusePort bool + bindToDevice tcpip.NICID broadcast bool + // Values used to reserve a port or register a transport endpoint. + // (which ever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -105,6 +128,9 @@ type endpoint struct { // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). effectiveNetProtos []tcpip.NetworkProtocolNumber + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` } // +stateify savable @@ -113,10 +139,13 @@ type multicastMembership struct { multicastAddr tcpip.Address } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.UDPProtocolNumber, + }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership // requests. @@ -134,9 +163,16 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -145,12 +181,14 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } for _, mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } e.multicastMemberships = nil @@ -190,6 +228,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -251,43 +290,61 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { - localAddr := e.id.LocalAddress +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { + localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" } if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicid == 0 { - nicid = e.multicastNICID + if nicID == 0 { + nicID = e.multicastNICID } - if localAddr == "" && nicid == 0 { + if localAddr == "" && nicID == 0 { localAddr = e.multicastAddr } } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, err } - return r, nicid, nil + return r, nicID, nil } // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -333,13 +390,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + nicID := to.NIC + if e.BindNICID != 0 { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicID = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -351,7 +408,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, err } - r, _, err := e.connectRoute(nicid, *to, netProto) + r, _, err := e.connectRoute(nicID, *to, netProto) if err != nil { return 0, nil, err } @@ -370,17 +427,25 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } + + ttl := e.ttl + useDefaultTTL := ttl == 0 - ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { ttl = e.multicastTTL + // Multicast allows a 0 TTL. + useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -391,12 +456,17 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -410,6 +480,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -444,7 +519,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if e.bindNICID != 0 && e.bindNICID != nic { + if e.BindNICID != 0 && e.BindNICID != nic { return tcpip.ErrInvalidEndpointState } @@ -471,7 +546,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -488,7 +563,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -509,7 +584,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -531,7 +606,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrBadLocalAddress } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -548,12 +623,39 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.reusePort = v != 0 e.mu.Unlock() + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicID, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicID + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 e.mu.Unlock() return nil + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil } return nil } @@ -570,7 +672,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -580,21 +695,9 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -608,6 +711,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.MulticastTTLOption: e.mu.Lock() *o = tcpip.MulticastTTLOption(e.multicastTTL) @@ -631,6 +740,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.MulticastLoopOption(v) return nil + case *tcpip.ReuseAddressOption: + *o = 0 + return nil + case *tcpip.ReusePortOption: e.mu.RLock() v := e.reusePort @@ -642,6 +755,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = tcpip.BindToDeviceOption("") + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -657,6 +780,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -664,7 +799,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -687,14 +822,25 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u udp.SetChecksum(^udp.CalculateChecksum(xsum)) } + if useDefaultTTL { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: hdr, + Data: data, + TransportHeader: buffer.View(udp), + }); err != nil { + r.Stats().UDP.PacketSendErrors.Increment() + return err + } + // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - - return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl) + return nil } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if len(addr.Addr) == 0 { return netProto, nil } @@ -711,14 +857,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t } // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.id.LocalAddress) == 16 { + if !allowMismatch && len(e.ID.LocalAddress) == 16 { return 0, tcpip.ErrNetworkUnreachable } } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -733,25 +879,34 @@ func (e *endpoint) Disconnect() *tcpip.Error { if e.state != StateConnected { return nil } - id := stack.TransportEndpointID{} + var ( + id stack.TransportEndpointID + btd tcpip.NICID + ) // Exclude ephemerally bound endpoints. - if e.bindNICID != 0 || e.id.LocalAddress == "" { + if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.id.LocalPort, - LocalAddress: e.id.LocalAddress, + LocalPort: e.ID.LocalPort, + LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } e.state = StateBound } else { + if e.ID.LocalPort != 0 { + // Release the ephemeral port. + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundPortFlags = ports.Flags{} + } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.id = id + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.ID = id + e.boundBindToDevice = btd e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -773,33 +928,33 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC var localPort uint16 switch e.state { case StateInitial: case StateBound, StateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - r, nicid, err := e.connectRoute(nicid, addr, netProto) + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err } defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: e.id.LocalAddress, + LocalAddress: e.ID.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, @@ -820,20 +975,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } // Remove the old registration. - if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + if e.ID.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) } - e.id = id + e.ID = id + e.boundBindToDevice = btd e.route = r.Clone() e.dstPort = addr.Port - e.regNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos e.state = StateConnected @@ -888,20 +1044,27 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + if e.ID.LocalPort == 0 { + flags := ports.Flags{ + LoadBalanced: e.reusePort, + // FIXME(b/129164367): Support SO_REUSEADDR. + MostRecent: false, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) if err != nil { - return id, err + return id, e.bindToDevice, err } + e.boundPortFlags = flags id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.boundPortFlags = ports.Flags{} } - return id, err + return id, e.bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -927,11 +1090,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } } - nicid := addr.NIC + nicID := addr.NIC if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { // A local unicast address was specified, verify that it's valid. - nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicid == 0 { + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { return tcpip.ErrBadLocalAddress } } @@ -940,13 +1103,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } - e.id = id - e.regNICID = nicid + e.ID = id + e.boundBindToDevice = btd + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -971,7 +1135,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // Save the effective NICID generated by bindLocked. - e.bindNICID = e.regNICID + e.BindNICID = e.RegisterNICID return nil } @@ -981,10 +1145,15 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: addr, + Port: e.ID.LocalPort, }, nil } @@ -998,9 +1167,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -1024,42 +1193,52 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + pkt.Data.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() + e.stats.PacketsReceived.Increment() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() @@ -1070,7 +1249,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. @@ -1080,6 +1259,23 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index be46e6d4e..43fb047ed 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -72,7 +72,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } } @@ -93,13 +93,13 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } @@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) { // Our saved state had a port, but we don't actually have a // reservation. We need to remove the port from our state, but still // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id := e.ID + e.ID.LocalPort = 0 + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a9edc2c8d..fc706ede2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt tcpip.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -74,15 +73,15 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { ep.Close() return nil, err } - ep.id = r.id + ep.ID = r.id ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort - ep.regNICID = r.route.NICID() + ep.RegisterNICID = r.route.NICID() ep.state = StateConnected @@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 068d9a272..259c3072a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -14,7 +14,7 @@ // Package udp contains the implementation of the UDP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// activated on the stack by passing udp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing udp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -30,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the udp protocol name. - ProtocolName = "udp" - // ProtocolNumber is the udp protocol number. ProtocolNumber = header.UDPProtocolNumber ) @@ -69,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -119,13 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) + payload := newNetHeader.ToVectorisedView() + payload.Append(pkt.Data.ToView().ToVectorisedView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -133,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -154,12 +159,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -167,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) } return true } @@ -182,8 +190,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{} - }) +// NewProtocol returns a UDP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 995d6e8a1..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -17,7 +17,6 @@ package udp_test import ( "bytes" "fmt" - "math" "math/rand" "testing" "time" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -274,13 +274,17 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) - id, linkEP := channel.New(256, mtu, "") if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } @@ -352,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw if p.Proto != flow.netProto() { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) h := flow.header4Tuple(outgoing) checkers := append( @@ -380,18 +384,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h) + c.injectV4Packet(payload, &h, true /* valid */) } else { - c.injectV6Packet(payload, &h) + c.injectV6Packet(payload, &h, true /* valid */) } } // injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) @@ -405,10 +412,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) + l := uint16(header.UDPMinimumSize + len(payload)) + if !valid { + // Change the UDP payload length to corrupt the header + // as requested by the caller. + l++ + } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), + Length: l, }) // Calculate the UDP pseudo-header checksum. @@ -419,15 +432,22 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } -// injectV6Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) @@ -457,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } func newPayload() []byte { @@ -472,94 +497,67 @@ func newMinPayload(minSize int) []byte { return b } -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) - - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) } + defer ep.Close() - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.injectV6Packet(payload, &header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - }) + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) } - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s } - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) } } -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// testReadInternal sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then attempts to read it from the +// UDP endpoint and depending on if this was expected to succeed verifies its +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -570,6 +568,9 @@ func testRead(c *testContext, flow testFlow) { c.wq.EventRegister(&we, waiter.EventIn) defer c.wq.EventUnregister(&we) + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + var addr tcpip.FullAddress v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { @@ -577,25 +578,55 @@ func testRead(c *testContext, flow testFlow) { select { case <-ch: v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") + case <-time.After(300 * time.Millisecond): + if packetShouldBeDropped { + return // expected to time out + } + c.t.Fatal("timed out waiting for data") } } + if expectReadError && err != nil { + c.checkEndpointReadStats(1, epstats, err) + return + } + + if err != nil { + c.t.Fatal("Read failed:", err) + } + + if packetShouldBeDropped { + c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + } + // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + c.checkEndpointReadStats(1, epstats, err) +} + +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) +} + +// testFailingRead sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then tries to read it from the UDP +// endpoint and expects this to fail. +func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { + c.t.Helper() + testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) } func TestBindEphemeralPort(t *testing.T) { @@ -767,13 +798,17 @@ func TestReadOnBoundToMulticast(t *testing.T) { c.t.Fatal("SetSockOpt failed:", err) } + // Check that we receive multicast packets but not unicast or broadcast + // ones. testRead(c, flow) + testFailingRead(c, broadcast, false /* expectReadError */) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and receive broadcast data on it. +// address and can receive only broadcast data. func TestV4ReadOnBoundToBroadcast(t *testing.T) { for _, flow := range []testFlow{broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -788,8 +823,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - // Test acceptance. + // Check that we receive broadcast packets but not unicast ones. + testRead(c, flow) + testFailingRead(c, unicastV4, false /* expectReadError */) + }) + } +} + +// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY +// and receive broadcast and unicast data. +func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s (", err) + } + + // Check that we receive both broadcast and unicast packets. testRead(c, flow) + testRead(c, unicastV4) }) } } @@ -798,7 +856,8 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { // and verifies it fails with the provided error code. func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { c.t.Helper() - + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) @@ -806,6 +865,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) + c.checkEndpointWriteStats(1, epstats, gotErr) if gotErr != wantErr { c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } @@ -831,6 +891,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() writeOpts := tcpip.WriteOptions{} if setDest { @@ -848,7 +910,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - + c.checkEndpointWriteStats(1, epstats, err) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -917,6 +979,10 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Write to V4 mapped address. testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + const want = 1 + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { + c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -1179,6 +1245,109 @@ func TestTTL(t *testing.T) { } } +func TestSetTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() + } + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + ep.Close() + + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } + }) + } +} + +func TestTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const tos = 0xC0 + var v tcpip.IPv4TOSOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != 0 { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + } + + if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testWrite(c, flow, checker.TOS(tos, 0)) + }) + } +} + +func TestTOSV6(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const tos = 0xC0 + var v tcpip.IPv6TrafficClassOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != 0 { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + } + + if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + c.t.Errorf("SetSockOpt failed: %s", err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testWrite(c, flow, checker.TOS(tos, 0)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1284,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1358,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1392,3 +1561,117 @@ func TestV6UnknownDestination(t *testing.T) { }) } } + +// TestIncrementMalformedPacketsReceived verifies if the malformed received +// global and endpoint stats get incremented. +func TestIncrementMalformedPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + payload := newPayload() + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + c.injectV6Packet(payload, &h, false /* !valid */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } +} + +// TestShutdownRead verifies endpoint read shutdown and error +// stats increment on packet receive. +func TestShutdownRead(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingRead(c, unicastV6, true /* expectReadError */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { + t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) + } +} + +// TestShutdownWrite verifies endpoint write shutdown and error +// stats increment on packet write. +func TestShutdownWrite(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) +} + +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil: + want.PacketsSent.IncrementBy(incr) + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + want.WriteErrors.InvalidArgs.IncrementBy(incr) + case tcpip.ErrClosedForSend: + want.WriteErrors.WriteClosed.IncrementBy(incr) + case tcpip.ErrInvalidEndpointState: + want.WriteErrors.InvalidEndpointState.IncrementBy(incr) + case tcpip.ErrNoLinkAddress: + want.SendErrors.NoLinkAddr.IncrementBy(incr) + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + want.SendErrors.NoRoute.IncrementBy(incr) + default: + want.SendErrors.SendToNetworkFailed.IncrementBy(incr) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} + +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil, tcpip.ErrWouldBlock: + case tcpip.ErrClosedForReceive: + want.ReadErrors.ReadClosed.IncrementBy(incr) + default: + c.t.Errorf("Endpoint error missing stats update err %v", err) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 98d51cc69..6afdb29b7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index cbd92fc05..8f6f180e5 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b7f505a84..b6bbb0ea2 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 9173dfd0f..0427bc41f 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "waiter_list", @@ -33,11 +34,3 @@ go_test( ], embed = [":waiter"], ) - -filegroup( - name = "autogen", - srcs = [ - "waiter_list.go", - ], - visibility = ["//:sandbox"], -) |