Diffstat (limited to 'pkg')
381 files changed, 24499 insertions, 3265 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 7c17109a6..9553f164d 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -23,6 +23,8 @@ go_library( "exec.go", "fcntl.go", "file.go", + "file_amd64.go", + "file_arm64.go", "fs.go", "futex.go", "inotify.go", @@ -55,6 +57,7 @@ go_library( "uio.go", "utsname.go", "wait.go", + "xattr.go", ], importpath = "gvisor.dev/gvisor/pkg/abi/linux", visibility = ["//visibility:public"], diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index f78315ebf..6663a199c 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -16,15 +16,17 @@ package linux // Commands from linux/fcntl.h. const ( - F_DUPFD = 0x0 - F_GETFD = 0x1 - F_SETFD = 0x2 - F_GETFL = 0x3 - F_SETFL = 0x4 - F_SETLK = 0x6 - F_SETLKW = 0x7 - F_SETOWN = 0x8 - F_GETOWN = 0x9 + F_DUPFD = 0 + F_GETFD = 1 + F_SETFD = 2 + F_GETFL = 3 + F_SETFL = 4 + F_SETLK = 6 + F_SETLKW = 7 + F_SETOWN = 8 + F_GETOWN = 9 + F_SETOWN_EX = 15 + F_GETOWN_EX = 16 F_DUPFD_CLOEXEC = 1024 + 6 F_SETPIPE_SZ = 1024 + 7 F_GETPIPE_SZ = 1024 + 8 @@ -32,9 +34,9 @@ const ( // Commands for F_SETLK. const ( - F_RDLCK = 0x0 - F_WRLCK = 0x1 - F_UNLCK = 0x2 + F_RDLCK = 0 + F_WRLCK = 1 + F_UNLCK = 2 ) // Flags for fcntl. @@ -42,7 +44,7 @@ const ( FD_CLOEXEC = 00000001 ) -// Lock structure for F_SETLK. +// Flock is the lock structure for F_SETLK. type Flock struct { Type int16 Whence int16 @@ -52,3 +54,16 @@ type Flock struct { Pid int32 _ [4]byte } + +// Flags for F_SETOWN_EX and F_GETOWN_EX. +const ( + F_OWNER_TID = 0 + F_OWNER_PID = 1 + F_OWNER_PGRP = 2 +) + +// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +type FOwnerEx struct { + Type int32 + PID int32 +} diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 257f67222..16791d03e 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -144,9 +144,13 @@ const ( ModeCharacterDevice = S_IFCHR ModeNamedPipe = S_IFIFO - ModeSetUID = 04000 - ModeSetGID = 02000 - ModeSticky = 01000 + S_ISUID = 04000 + S_ISGID = 02000 + S_ISVTX = 01000 + + ModeSetUID = S_ISUID + ModeSetGID = S_ISGID + ModeSticky = S_ISVTX ModeUserAll = 0700 ModeUserRead = 0400 @@ -186,25 +190,6 @@ const ( RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC ) -// Stat represents struct stat. -type Stat struct { - Dev uint64 - Ino uint64 - Nlink uint64 - Mode uint32 - UID uint32 - GID uint32 - _ int32 - Rdev uint64 - Size int64 - Blksize int64 - Blocks int64 - ATime Timespec - MTime Timespec - CTime Timespec - _ [3]int64 -} - // SizeOfStat is the size of a Stat struct. var SizeOfStat = binary.Size(Stat{}) @@ -301,6 +286,29 @@ func (m FileMode) String() string { return strings.Join(s, "|") } +// DirentType maps file types to dirent types appropriate for (struct +// dirent)::d_type. +func (m FileMode) DirentType() uint8 { + switch m.FileType() { + case ModeSocket: + return DT_SOCK + case ModeSymlink: + return DT_LNK + case ModeRegular: + return DT_REG + case ModeBlockDevice: + return DT_BLK + case ModeDirectory: + return DT_DIR + case ModeCharacterDevice: + return DT_CHR + case ModeNamedPipe: + return DT_FIFO + default: + return DT_UNKNOWN + } +} + var modeExtraBits = abi.FlagSet{ { Flag: ModeSetUID, diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go new file mode 100644 index 000000000..74c554be6 --- /dev/null +++ b/pkg/abi/linux/file_amd64.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go new file mode 100644 index 000000000..f16c07589 --- /dev/null +++ b/pkg/abi/linux/file_arm64.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Mode uint32 + Nlink uint32 + UID uint32 + GID uint32 + Rdev uint64 + _ uint64 + Size int64 + Blksize int32 + _ int32 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [2]int32 +} diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index b416e3472..2c652baa2 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -92,3 +92,10 @@ const ( SYNC_FILE_RANGE_WRITE = 2 SYNC_FILE_RANGE_WAIT_AFTER = 4 ) + +// Flag argument to renameat2(2), from include/uapi/linux/fs.h. +const ( + RENAME_NOREPLACE = (1 << 0) // Don't overwrite target. + RENAME_EXCHANGE = (1 << 1) // Exchange src and dst. + RENAME_WHITEOUT = (1 << 2) // Whiteout src. +) diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 152f6b144..3898d2314 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -325,3 +325,9 @@ const ( RTA_SPORT = 28 RTA_DPORT = 29 ) + +// Route flags, from include/uapi/linux/route.h. +const ( + RTF_GATEWAY = 0x2 + RTF_UP = 0x1 +) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d5b731390..766ee4014 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -256,6 +256,17 @@ type SockAddrInet6 struct { Scope_id uint32 } +// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +type SockAddrLink struct { + Family uint16 + Protocol uint16 + InterfaceIndex int32 + ARPHardwareType uint16 + PacketType byte + HardwareAddrLen byte + HardwareAddr [8]byte +} + // UnixPathMax is the maximum length of the path in an AF_UNIX socket. // // From uapi/linux/un.h. @@ -278,6 +289,7 @@ type SockAddr interface { func (s *SockAddrInet) implementsSockAddr() {} func (s *SockAddrInet6) implementsSockAddr() {} +func (s *SockAddrLink) implementsSockAddr() {} func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} @@ -410,6 +422,15 @@ type ControlMessageRights []int32 // ControlMessageRights. 
const SizeOfControlMessageRight = 4 +// SizeOfControlMessageInq is the size of a TCP_INQ control message. +const SizeOfControlMessageInq = 4 + +// SizeOfControlMessageTOS is the size of an IP_TOS control message. +const SizeOfControlMessageTOS = 1 + +// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message. +const SizeOfControlMessageTClass = 4 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go new file mode 100644 index 000000000..a3b6406fa --- /dev/null +++ b/pkg/abi/linux/xattr.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Constants for extended attributes. +const ( + XATTR_NAME_MAX = 255 + XATTR_SIZE_MAX = 65536 + + XATTR_CREATE = 1 + XATTR_REPLACE = 2 + + XATTR_USER_PREFIX = "user." + XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX) +) diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 5d61dc2ff..d37047368 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -183,6 +183,33 @@ const ( X86FeatureAVX512VBMI X86FeatureUMIP X86FeaturePKU + X86FeatureOSPKE + X86FeatureWAITPKG + X86FeatureAVX512_VBMI2 + _ // ecx bit 7 is reserved + X86FeatureGFNI + X86FeatureVAES + X86FeatureVPCLMULQDQ + X86FeatureAVX512_VNNI + X86FeatureAVX512_BITALG + X86FeatureTME + X86FeatureAVX512_VPOPCNTDQ + _ // ecx bit 15 is reserved + X86FeatureLA57 + // ecx bits 17-21 are reserved + _ + _ + _ + _ + _ + X86FeatureRDPID + // ecx bits 23-24 are reserved + _ + _ + X86FeatureCLDEMOTE + _ // ecx bit 26 is reserved + X86FeatureMOVDIRI + X86FeatureMOVDIR64B ) // Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX. @@ -353,9 +380,24 @@ var x86FeatureStrings = map[Feature]string{ X86FeatureAVX512VL: "avx512vl", // Block 3. - X86FeatureAVX512VBMI: "avx512vbmi", - X86FeatureUMIP: "umip", - X86FeaturePKU: "pku", + X86FeatureAVX512VBMI: "avx512vbmi", + X86FeatureUMIP: "umip", + X86FeaturePKU: "pku", + X86FeatureOSPKE: "ospke", + X86FeatureWAITPKG: "waitpkg", + X86FeatureAVX512_VBMI2: "avx512_vbmi2", + X86FeatureGFNI: "gfni", + X86FeatureVAES: "vaes", + X86FeatureVPCLMULQDQ: "vpclmulqdq", + X86FeatureAVX512_VNNI: "avx512_vnni", + X86FeatureAVX512_BITALG: "avx512_bitalg", + X86FeatureTME: "tme", + X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq", + X86FeatureLA57: "la57", + X86FeatureRDPID: "rdpid", + X86FeatureCLDEMOTE: "cldemote", + X86FeatureMOVDIRI: "movdiri", + X86FeatureMOVDIR64B: "movdir64b", // Block 4. 
X86FeatureXSAVEOPT: "xsaveopt", diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 71f2abc83..0b4b7cc44 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -25,6 +25,7 @@ go_library( proto_library( name = "eventchannel_proto", srcs = ["event.proto"], + visibility = ["//:sandbox"], ) go_proto_library( diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index c7f549428..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -8,9 +8,6 @@ go_library( srcs = ["fd.go"], importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], - deps = [ - "//pkg/unet", - ], ) go_test( diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 7691b477b..83bcfe220 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -22,8 +22,6 @@ import ( "runtime" "sync/atomic" "syscall" - - "gvisor.dev/gvisor/pkg/unet" ) // ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It @@ -187,12 +185,6 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) { return New(f), nil } -// DialUnix connects to a Unix Domain Socket and return the file descriptor. -func DialUnix(path string) (*FD, error) { - socket, err := unet.Connect(path, false) - return New(socket.FD()), err -} - // Close closes the file descriptor contained in the FD. // // Close is safe to call multiple times, but will return an error after the diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 5643d5f26..e590a71ba 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -19,7 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", - "//third_party/gvsync", + "//pkg/syncutil", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 8390915a2..e7c3a3a0b 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error { return nil case epsBlocked | epsShutdown: atomic.AddInt32(&ep.ctrl.state, -epsBlocked) - return shutdownError{} + return ShutdownError{} default: // Most likely due to ep.enterFutexWait() being called concurrently // from multiple goroutines. diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 386cee42c..3cdb576e1 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() { // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), // ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer -// Endpoint, to unblock and return errors. It does not wait for concurrent -// calls to return. Successive calls to Shutdown have no effect. +// Endpoint, to unblock and return ShutdownErrors. It does not wait for +// concurrent calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool { return atomic.LoadUint32(&ep.shutdown) != 0 } -type shutdownError struct{} +// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown() +// has been called. +type ShutdownError struct{} // Error implements error.Error. 
-func (shutdownError) Error() string { +func (ShutdownError) Error() string { return "flipcall connection shutdown" } diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index a37952637..27b8939fc 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -18,7 +18,7 @@ import ( "reflect" "unsafe" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte { var ioSync int64 func raceBecomeActive() { - if gvsync.RaceEnabled { - gvsync.RaceAcquire((unsafe.Pointer)(&ioSync)) + if syncutil.RaceEnabled { + syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) } } func raceBecomeInactive() { - if gvsync.RaceEnabled { - gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + if syncutil.RaceEnabled { + syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) } } diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index b127a2bbb..168c1ccff 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error { if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { switch cs := atomic.LoadUint32(ep.connState()); cs { case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } @@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error { return nil case ep.inactiveState: if ep.isShutdownLocally() { - return shutdownError{} + return ShutdownError{} } if err := ep.futexWaitConnState(ep.inactiveState); err != nil { return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) } continue case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index a6cc0617e..de9357389 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -17,7 +17,6 @@ package p9 import ( "fmt" "io" - "runtime" "sync/atomic" "syscall" @@ -45,15 +44,10 @@ func (c *Client) Attach(name string) (File, error) { // newFile returns a new client file. func (c *Client) newFile(fid FID) *clientFile { - cf := &clientFile{ + return &clientFile{ client: c, fid: fid, } - - // Make sure the file is closed. - runtime.SetFinalizer(cf, (*clientFile).Close) - - return cf } // clientFile is provided to clients. @@ -192,7 +186,6 @@ func (c *clientFile) Remove() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the remove message. if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil { @@ -214,7 +207,6 @@ func (c *clientFile) Close() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return syscall.EBADF } - runtime.SetFinalizer(c, nil) // Send the close message. if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 907445e15..96d1f2a8e 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -116,7 +116,7 @@ type File interface { // N.B. The server must resolve any lazy paths when open is called. // After this point, read and write may be called on files with no // deletion check, so resolving in the data path is not viable. 
- Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // Read reads from this file. Open must be called first. // diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index ba9a55d6d..b9582c07f 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -257,7 +257,6 @@ func CanOpen(mode FileMode) bool { // handle implements handler.handle. func (t *Tlopen) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -272,15 +271,15 @@ func (t *Tlopen) handle(cs *connState) message { return newErr(syscall.EINVAL) } - // Are flags valid? - flags := t.Flags &^ OpenFlagsIgnoreMask - if flags&^OpenFlagsModeMask != 0 { - return newErr(syscall.EINVAL) - } - - // Is this an attempt to open a directory as writable? Don't accept. - if ref.mode.IsDir() && flags != ReadOnly { - return newErr(syscall.EINVAL) + if ref.mode.IsDir() { + // Directory must be opened ReadOnly. + if t.Flags&OpenFlagsModeMask != ReadOnly { + return newErr(syscall.EISDIR) + } + // Directory not truncatable. + if t.Flags&OpenTruncate != 0 { + return newErr(syscall.EISDIR) + } } var ( @@ -294,7 +293,6 @@ func (t *Tlopen) handle(cs *connState) message { return syscall.EINVAL } - // Do the open. osFile, qid, ioUnit, err = ref.file.Open(t.Flags) return err }); err != nil { @@ -311,12 +309,10 @@ func (t *Tlopen) handle(cs *connState) message { } func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return nil, syscall.EBADF @@ -390,12 +386,10 @@ func (t *Tsymlink) handle(cs *connState) message { } func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -426,19 +420,16 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { // handle implements handler.handle. func (t *Tlink) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the other FID. refTarget, ok := cs.LookupFID(t.Target) if !ok { return newErr(syscall.EBADF) @@ -467,7 +458,6 @@ func (t *Tlink) handle(cs *connState) message { // handle implements handler.handle. func (t *Trenameat) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.OldName); err != nil { return newErr(err) } @@ -475,14 +465,12 @@ func (t *Trenameat) handle(cs *connState) message { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.OldDirectory) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the other FID. refTarget, ok := cs.LookupFID(t.NewDirectory) if !ok { return newErr(syscall.EBADF) @@ -523,12 +511,10 @@ func (t *Trenameat) handle(cs *connState) message { // handle implements handler.handle. func (t *Tunlinkat) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. 
ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -577,19 +563,16 @@ func (t *Tunlinkat) handle(cs *connState) message { // handle implements handler.handle. func (t *Trename) handle(cs *connState) message { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return newErr(err) } - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) } defer ref.DecRef() - // Lookup the target. refTarget, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -641,7 +624,6 @@ func (t *Trename) handle(cs *connState) message { // handle implements handler.handle. func (t *Treadlink) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -669,7 +651,6 @@ func (t *Treadlink) handle(cs *connState) message { // handle implements handler.handle. func (t *Tread) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -708,7 +689,6 @@ func (t *Tread) handle(cs *connState) message { // handle implements handler.handle. func (t *Twrite) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -747,12 +727,10 @@ func (t *Tmknod) handle(cs *connState) message { } func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -791,12 +769,10 @@ func (t *Tmkdir) handle(cs *connState) message { } func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { - // Don't allow complex names. if err := checkSafeName(t.Name); err != nil { return nil, err } - // Lookup the FID. ref, ok := cs.LookupFID(t.Directory) if !ok { return nil, syscall.EBADF @@ -827,7 +803,6 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { // handle implements handler.handle. func (t *Tgetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -856,7 +831,6 @@ func (t *Tgetattr) handle(cs *connState) message { // handle implements handler.handle. func (t *Tsetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -883,7 +857,6 @@ func (t *Tsetattr) handle(cs *connState) message { // handle implements handler.handle. func (t *Tallocate) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -917,7 +890,6 @@ func (t *Tallocate) handle(cs *connState) message { // handle implements handler.handle. func (t *Txattrwalk) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -930,7 +902,6 @@ func (t *Txattrwalk) handle(cs *connState) message { // handle implements handler.handle. func (t *Txattrcreate) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -943,7 +914,6 @@ func (t *Txattrcreate) handle(cs *connState) message { // handle implements handler.handle. func (t *Treaddir) handle(cs *connState) message { - // Lookup the FID. 
ref, ok := cs.LookupFID(t.Directory) if !ok { return newErr(syscall.EBADF) @@ -977,7 +947,6 @@ func (t *Treaddir) handle(cs *connState) message { // handle implements handler.handle. func (t *Tfsync) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1001,7 +970,6 @@ func (t *Tfsync) handle(cs *connState) message { // handle implements handler.handle. func (t *Tstatfs) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1192,7 +1160,6 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI // handle implements handler.handle. func (t *Twalk) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1213,7 +1180,6 @@ func (t *Twalk) handle(cs *connState) message { // handle implements handler.handle. func (t *Twalkgetattr) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1270,7 +1236,6 @@ func (t *Tumknod) handle(cs *connState) message { // handle implements handler.handle. func (t *Tlconnect) handle(cs *connState) message { - // Lookup the FID. ref, ok := cs.LookupFID(t.FID) if !ok { return newErr(syscall.EBADF) @@ -1303,7 +1268,6 @@ func (t *Tchannel) handle(cs *connState) message { return newErr(err) } - // Lookup the given channel. ch := cs.lookupChannel(t.ID) if ch == nil { return newErr(syscall.ENOSYS) diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 25530adca..d3090535a 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -32,21 +32,21 @@ import ( type OpenFlags uint32 const ( - // ReadOnly is a Topen and Tcreate flag indicating read-only mode. + // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode. ReadOnly OpenFlags = 0 - // WriteOnly is a Topen and Tcreate flag indicating write-only mode. + // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode. WriteOnly OpenFlags = 1 - // ReadWrite is a Topen flag indicates read-write mode. + // ReadWrite is a Tlopen flag indicates read-write mode. ReadWrite OpenFlags = 2 // OpenFlagsModeMask is a mask of valid OpenFlags mode bits. OpenFlagsModeMask OpenFlags = 3 - // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen. - // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h. - OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000 + // OpenTruncate is a Tlopen flag indicating that the opened file should be + // truncated. + OpenTruncate OpenFlags = 01000 ) // ConnectFlags is the mode passed to Connect operations. @@ -71,25 +71,32 @@ const ( // OSFlags converts a p9.OpenFlags to an int compatible with open(2). func (o OpenFlags) OSFlags() int { - return int(o & OpenFlagsModeMask) + // "flags contains Linux open(2) flags bits" - 9P2000.L + return int(o) } // String implements fmt.Stringer. 
func (o OpenFlags) String() string { - switch o { + var buf strings.Builder + switch mode := o & OpenFlagsModeMask; mode { case ReadOnly: - return "ReadOnly" + buf.WriteString("ReadOnly") case WriteOnly: - return "WriteOnly" + buf.WriteString("WriteOnly") case ReadWrite: - return "ReadWrite" - case OpenFlagsModeMask: - return "OpenFlagsModeMask" - case OpenFlagsIgnoreMask: - return "OpenFlagsIgnoreMask" + buf.WriteString("ReadWrite") default: - return "UNDEFINED" + fmt.Fprintf(&buf, "%#o", mode) } + otherFlags := o &^ OpenFlagsModeMask + if otherFlags&OpenTruncate != 0 { + buf.WriteString("|OpenTruncate") + otherFlags &^= OpenTruncate + } + if otherFlags != 0 { + fmt.Fprintf(&buf, "|%#o", otherFlags) + } + return buf.String() } // Tag is a message tag. @@ -814,7 +821,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) { attr.Mode = FileMode(s.Mode) } if req.NLink { - attr.NLink = s.Nlink + attr.NLink = uint64(s.Nlink) } if req.UID { attr.UID = UID(s.Uid) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 8bbdb2488..6e758148d 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -1044,11 +1044,11 @@ func TestReaddir(t *testing.T) { if _, err := f.Readdir(0, 1); err != syscall.EINVAL { t.Errorf("readdir got %v, wanted EINVAL", err) } - if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) + if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) } - if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EINVAL { - t.Errorf("readdir got %v, wanted EINVAL", err) + if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR { + t.Errorf("readdir got %v, wanted EISDIR", err) } backend.EXPECT().Open(p9.ReadOnly).Times(1) if _, _, _, err := f.Open(p9.ReadOnly); err != nil { @@ -1065,75 +1065,93 @@ func TestReaddir(t *testing.T) { func TestOpen(t *testing.T) { type openTest struct { name string - mode p9.OpenFlags + flags p9.OpenFlags err error match func(p9.FileMode) bool } cases := []openTest{ { - name: "invalid", - mode: ^p9.OpenFlagsModeMask, - err: syscall.EINVAL, - match: func(p9.FileMode) bool { return true }, - }, - { name: "not-openable-read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "not-openable-write-only", - mode: p9.WriteOnly, + flags: p9.WriteOnly, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "not-openable-read-write", - mode: p9.ReadWrite, + flags: p9.ReadWrite, err: syscall.EINVAL, match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) }, }, { name: "directory-read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: nil, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "directory-read-write", - mode: p9.ReadWrite, - err: syscall.EINVAL, + flags: p9.ReadWrite, + err: syscall.EISDIR, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "directory-write-only", - mode: p9.WriteOnly, - err: syscall.EINVAL, + flags: p9.WriteOnly, + err: syscall.EISDIR, match: func(mode p9.FileMode) bool { return mode.IsDir() }, }, { name: "read-only", - mode: p9.ReadOnly, + flags: p9.ReadOnly, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) }, }, { name: "write-only", - mode: p9.WriteOnly, + flags: p9.WriteOnly, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && 
!mode.IsDir() }, }, { name: "read-write", - mode: p9.ReadWrite, + flags: p9.ReadWrite, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "directory-read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: syscall.EISDIR, + match: func(mode p9.FileMode) bool { return mode.IsDir() }, + }, + { + name: "read-only-truncate", + flags: p9.ReadOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "write-only-truncate", + flags: p9.WriteOnly | p9.OpenTruncate, + err: nil, + match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, + }, + { + name: "read-write-truncate", + flags: p9.ReadWrite | p9.OpenTruncate, err: nil, match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() }, }, } - // Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + // Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // - only works on Regular, NamedPipe, BLockDevice, CharacterDevice // - returning a file works as expected for name := range newTypeMap(nil) { @@ -1171,25 +1189,25 @@ func TestOpen(t *testing.T) { // Attempt the given open. if tc.err != nil { // We expect an error, just test and return. - if _, _, _, err := f.Open(tc.mode); err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) + if _, _, _, err := f.Open(tc.flags); err != tc.err { + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) } return } // Run an FD test, since we expect success. fdTest(t, func(send *fd.FD) *fd.FD { - backend.EXPECT().Open(tc.mode).Return(send, p9.QID{}, uint32(0), nil).Times(1) - recv, _, _, err := f.Open(tc.mode) + backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1) + recv, _, _, err := f.Open(tc.flags) if err != tc.err { - t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err) + t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err) } return recv }) // If the open was successful, attempt another one. - if _, _, _, err := f.Open(tc.mode); err != syscall.EINVAL { - t.Errorf("second open with mode %v got %v, want EINVAL", tc.mode, err) + if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL { + t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err) } // Ensure that all illegal operations fail. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index e717e6161..40b8fa023 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -453,7 +453,11 @@ func (cs *connState) initializeChannels() (err error) { go func() { // S/R-SAFE: Server side. defer cs.channelWg.Done() if err := res.service(cs); err != nil { - log.Warningf("p9.channel.service: %v", err) + // Don't log flipcall.ShutdownErrors, which we expect to be + // returned during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("p9.channel.service: %v", err) + } } }() } diff --git a/pkg/p9/version.go b/pkg/p9/version.go index f1ffdd23a..36a694c58 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 8 + highestSupportedVersion uint32 = 9 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. 
@@ -155,3 +155,9 @@ func versionSupportsTallocate(v uint32) bool { func versionSupportsFlipcall(v uint32) bool { return v >= 8 } + +// VersionSupportsOpenTruncateFlag returns true if version v supports +// passing the OpenTruncate flag to Tlopen. +func VersionSupportsOpenTruncateFlag(v uint32) bool { + return v >= 9 +} diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 30ec8e6e2..38cea9be3 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index e340d9f98..4f4b70fef 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index c7503f2cc..fc36efa23 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -199,6 +199,10 @@ func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string { return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx) } +func ruleLabel(ruleSetIdx int, sysno uintptr, idx int, name string) string { + return fmt.Sprintf("rule_%v_%v_%v_%v", ruleSetIdx, sysno, idx, name) +} + func checkArgsLabel(sysno uintptr) string { return fmt.Sprintf("checkArgs_%v", sysno) } @@ -223,6 +227,19 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) labelled = true + case GreaterThan: + labelGood := fmt.Sprintf("gt%v", i) + high, low := uint32(a>>32), uint32(a) + // assert arg_high < high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + // arg_high > high + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + // arg_low < low + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true default: return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) } diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index 29eec8db1..84c841d7f 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -49,6 +49,9 @@ func (a AllowAny) String() (s string) { // AllowValue specifies a value that needs to be strictly matched. type AllowValue uintptr +// GreaterThan specifies a value that needs to be strictly smaller. 
+type GreaterThan uintptr + func (a AllowValue) String() (s string) { return fmt.Sprintf("%#x ", uintptr(a)) } diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 353686ed3..abbee7051 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -340,6 +340,54 @@ func TestBasic(t *testing.T) { }, }, }, + { + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + GreaterThan(0xf), + GreaterThan(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + specs: []spec{ + { + desc: "GreaterThan: Syscall argument allowed", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "GreaterThan: Syscall argument disallowed (equal)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "Syscall argument disallowed (smaller)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "GreaterThan2: Syscall argument allowed", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "GreaterThan2: Syscall argument disallowed (equal)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "GreaterThan2: Syscall argument disallowed (smaller)", + data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, } { instrs, err := BuildProgram(test.ruleSets, test.defaultAction) if err != nil { diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index 48413f1fb..da6b9eaaf 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -38,7 +38,7 @@ func main() { syscall.SYS_CLONE: {}, syscall.SYS_CLOSE: {}, syscall.SYS_DUP: {}, - syscall.SYS_DUP2: {}, + syscall.SYS_DUP3: {}, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_WAIT: {}, diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 9e7db8b30..67daa6c24 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -305,7 +305,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs()) return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return c.Native(0), nil } @@ -320,6 +320,6 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { _, err := c.PtraceSetRegs(bytes.NewBuffer(buf)) return err } - // TODO(b/34088053): debug registers + // Note: x86 debug registers are missing. return nil } diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go index dfd62cbdb..23e009ef3 100644 --- a/pkg/sentry/context/context.go +++ b/pkg/sentry/context/context.go @@ -12,10 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package context defines the sentry's Context type. +// Package context defines an internal context type. 
+// +// The given Context conforms to the standard Go context, but mandates +// additional methods that are specific to the kernel internals. Note however, +// that the Context described by this package carries additional constraints +// regarding concurrent access and retaining beyond the scope of a call. +// +// See the Context type for complete details. package context import ( + "context" + "time" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { type Context interface { log.Logger amutex.Sleeper + context.Context // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate @@ -72,19 +83,36 @@ type Context interface { // AddressSpace is activated. Normally activate is the same value as the // deactivate parameter passed to UninterruptibleSleepStart. UninterruptibleSleepFinish(activate bool) +} + +// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep +// methods for anonymous embedding in other types that do not implement sleeps. +type NoopSleeper struct { + amutex.NoopSleeper +} + +// UninterruptibleSleepStart does nothing. +func (NoopSleeper) UninterruptibleSleepStart(bool) {} + +// UninterruptibleSleepFinish does nothing. +func (NoopSleeper) UninterruptibleSleepFinish(bool) {} + +// Deadline returns zero values, meaning no deadline. +func (NoopSleeper) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns nil. +func (NoopSleeper) Done() <-chan struct{} { + return nil +} - // Value returns the value associated with this Context for key, or nil if - // no value is associated with key. Successive calls to Value with the same - // key returns the same result. - // - // A key identifies a specific value in a Context. Functions that wish to - // retrieve values from Context typically allocate a key in a global - // variable then use that key as the argument to Context.Value. A key can - // be any type that supports equality; packages should define keys as an - // unexported type to avoid collisions. - Value(key interface{}) interface{} +// Err returns nil. +func (NoopSleeper) Err() error { + return nil } +// logContext implements basic logging. type logContext struct { log.Logger NoopSleeper @@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} { return nil } -// NoopSleeper is a noop implementation of amutex.Sleeper and -// Context.UninterruptibleSleep* methods for anonymous embedding in other types -// that do not want to notify kernel.Task about sleeps. -type NoopSleeper struct { - amutex.NoopSleeper -} - -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} - -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} - // bgContext is the context returned by context.Background. var bgContext = &logContext{Logger: log.Log()} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 1f78d54a2..e1f2fea60 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -22,6 +22,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/urpc" ) @@ -56,6 +57,9 @@ type Profile struct { // traceFile is the current execution trace output file. traceFile *fd.FD + + // Kernel is the kernel under profile. 
+ Kernel *kernel.Kernel } // StartCPUProfile is an RPC stub which starts recording the CPU profile in a @@ -147,6 +151,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { return err } + // Ensure all trace contexts are registered. + p.Kernel.RebuildTraceContexts() + p.traceFile = output return nil } @@ -158,9 +165,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error { defer p.mu.Unlock() if p.traceFile == nil { - return errors.New("Execution tracing not start") + return errors.New("Execution tracing not started") } + // Similarly to the case above, if tasks have not ended traces, we will + // lose information. Thus we need to rebuild the tasks in order to have + // complete information. This will not lose information if multiple + // traces are overlapping. + p.Kernel.RebuildTraceContexts() + trace.Stop() p.traceFile.Close() p.traceFile = nil diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index c35faeb4c..ced51c66c 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -268,14 +268,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error { } // Process contains information about a single process in a Sandbox. -// TODO(b/117881927): Implement TTY field. type Process struct { UID auth.KUID `json:"uid"` PID kernel.ThreadID `json:"pid"` // Parent PID - PPID kernel.ThreadID `json:"ppid"` + PPID kernel.ThreadID `json:"ppid"` + Threads []kernel.ThreadID `json:"threads"` // Processor utilization C int32 `json:"c"` + // TTY name of the process. Will be of the form "pts/N" if there is a + // TTY, or "?" if there is not. + TTY string `json:"tty"` // Start time STime string `json:"stime"` // CPU time @@ -285,18 +288,19 @@ type Process struct { } // ProcessListToTable prints a table with the following format: -// UID PID PPID C STIME TIME CMD -// 0 1 0 0 14:04 505262ns tail +// UID PID PPID C TTY STIME TIME CMD +// 0 1 0 0 pty/4 14:04 505262ns tail func ProcessListToTable(pl []*Process) string { var buf bytes.Buffer tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0) - fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD") + fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD") for _, d := range pl { - fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s", + fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s", d.UID, d.PID, d.PPID, d.C, + d.TTY, d.STime, d.Time, d.Cmd) @@ -307,7 +311,7 @@ func ProcessListToTable(pl []*Process) string { // ProcessListToJSON will return the JSON representation of ps. func ProcessListToJSON(pl []*Process) (string, error) { - b, err := json.Marshal(pl) + b, err := json.MarshalIndent(pl, "", " ") if err != nil { return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err) } @@ -334,7 +338,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ts := k.TaskSet() now := k.RealtimeClock().Now() for _, tg := range ts.Root.ThreadGroups() { - pid := tg.PIDNamespace().IDOfThreadGroup(tg) + pidns := tg.PIDNamespace() + pid := pidns.IDOfThreadGroup(tg) + // If tg has already been reaped ignore it. 
if pid == 0 { continue @@ -345,16 +351,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error { ppid := kernel.ThreadID(0) if p := tg.Leader().Parent(); p != nil { - ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup()) + ppid = pidns.IDOfThreadGroup(p.ThreadGroup()) } + threads := tg.MemberIDs(pidns) *out = append(*out, &Process{ - UID: tg.Leader().Credentials().EffectiveKUID, - PID: pid, - PPID: ppid, - STime: formatStartTime(now, tg.Leader().StartTime()), - C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), - Time: tg.CPUStats().SysTime.String(), - Cmd: tg.Leader().Name(), + UID: tg.Leader().Credentials().EffectiveKUID, + PID: pid, + PPID: ppid, + Threads: threads, + STime: formatStartTime(now, tg.Leader().StartTime()), + C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now), + Time: tg.CPUStats().SysTime.String(), + Cmd: tg.Leader().Name(), + TTY: ttyName(tg.TTY()), }) } sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID }) @@ -395,3 +404,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 { } return int32(percentCPU) } + +func ttyName(tty *kernel.TTY) string { + if tty == nil { + return "?" + } + return fmt.Sprintf("pts/%d", tty.Index) +} diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go index d8ada2694..0a88459b2 100644 --- a/pkg/sentry/control/proc_test.go +++ b/pkg/sentry/control/proc_test.go @@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) { }{ { pl: []*Process{}, - expected: "UID PID PPID C STIME TIME CMD", + expected: "UID PID PPID C TTY STIME TIME CMD", }, { pl: []*Process{ @@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) { PID: 0, PPID: 0, C: 0, + TTY: "?", STime: "0", Time: "0", Cmd: "zero", @@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) { PID: 1, PPID: 1, C: 1, + TTY: "pts/4", STime: "1", Time: "1", Cmd: "one", }, }, - expected: `UID PID PPID C STIME TIME CMD -0 0 0 0 0 0 zero -1 1 1 1 1 1 one`, + expected: `UID PID PPID C TTY STIME TIME CMD +0 0 0 0 ? 0 0 zero +1 1 1 1 pts/4 1 1 one`, }, } diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 378602cc9..c035ffff7 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,9 +68,9 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/syncutil", "//pkg/syserror", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go index 8e4d839e1..577445148 100644 --- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go index 0fab876a9..4338ae1fa 100644 --- a/pkg/sentry/fs/flags.go +++ b/pkg/sentry/fs/flags.go @@ -64,6 +64,10 @@ type FileFlags struct { // NonSeekable indicates that file.offset isn't used. NonSeekable bool + + // Truncate indicates that the file should be truncated before opened. + // This is only applicable if the file is regular. 
+ Truncate bool } // SettableFileFlags is a subset of FileFlags above that can be changed @@ -118,6 +122,9 @@ func (f FileFlags) ToLinux() (mask uint) { if f.LargeFile { mask |= linux.O_LARGEFILE } + if f.Truncate { + mask |= linux.O_TRUNC + } switch { case f.Read && f.Write: diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index c2fbb4be9..bb8312849 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -28,8 +28,14 @@ func (f *fileOperations) afterLoad() { // Manually load the open handles. var err error + + // The file may have been opened with Truncate, but we don't + // want to re-open it with Truncate or we will lose data. + flags := f.flags + flags.Truncate = false + // TODO(b/38173783): Context is not plumbed to save/restore. - f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps) + f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps) if err != nil { return fmt.Errorf("failed to re-open handle: %v", err) } diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index 39c8ec33d..b86c49b39 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -64,7 +64,7 @@ func (h *handles) DecRef() { }) } -func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*handles, error) { +func newHandles(ctx context.Context, client *p9.Client, file contextFile, flags fs.FileFlags) (*handles, error) { _, newFile, err := file.walk(ctx, nil) if err != nil { return nil, err @@ -81,6 +81,9 @@ func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*han default: panic("impossible fs.FileFlags") } + if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(client.Version()) { + p9flags |= p9.OpenTruncate + } hostFile, _, _, err := newFile.open(ctx, p9flags) if err != nil { diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 99910388f..91263ebdc 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -180,7 +180,7 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles) // given flags. func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) { if !i.canShareHandles() { - return newHandles(ctx, i.file, flags) + return newHandles(ctx, i.s.client, i.file, flags) } i.handlesMu.Lock() @@ -201,19 +201,25 @@ func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cac // whether previously open read handle was recreated. Host mappings must be // invalidated if so. func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) { - // Do we already have usable shared handles? - if flags.Write { + // Check if we are able to use cached handles. + if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(i.s.client.Version()) { + // If we are truncating (and the gofer supports it), then we + // always need a new handle. Don't return one from the cache. + } else if flags.Write { if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) { + // File is opened for writing, and we have cached write + // handles that we can use. i.writeHandles.IncRef() return i.writeHandles, false, nil } } else if i.readHandles != nil { + // File is opened for reading and we have cached handles. 
i.readHandles.IncRef() return i.readHandles, false, nil } - // No; get new handles and cache them for future sharing. - h, err := newHandles(ctx, i.file, flags) + // Get new handles and cache them for future sharing. + h, err := newHandles(ctx, i.s.client, i.file, flags) if err != nil { return nil, false, err } @@ -239,7 +245,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle if !flags.Read { // Writer can't be used for read, must create a new handle. var err error - h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true}) + h, err = newHandles(ctx, i.s.client, i.file, fs.FileFlags{Read: true}) if err != nil { return err } @@ -268,7 +274,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle // operations on the old will see the new data. Then, make the new handle take // ownereship of the old FD and mark the old readHandle to not close the FD // when done. - if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil { + if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil { return err } diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 8c17603f8..c09f3b71c 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -234,6 +234,8 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, if err != nil { return nil, err } + // We're not going to use newFile after return. + defer newFile.close(ctx) // Stabilize the endpoint map while creation is in progress. unlock := i.session().endpoints.lock() @@ -254,7 +256,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, // Get the attributes of the file to create inode key. qid, mask, attr, err := getattr(ctx, newFile) if err != nil { - newFile.close(ctx) return nil, err } @@ -270,7 +271,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string, // cloned and re-opened multiple times after creation. _, unopened, err := i.fileState.file.walk(ctx, []string{name}) if err != nil { - newFile.close(ctx) return nil, err } diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 0da608548..4e358a46a 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -143,9 +143,9 @@ type session struct { // socket files. This allows unix domain sockets to be used with paths that // belong to a gofer. // - // TODO(b/77154739): there are few possible races with someone stat'ing the - // file and another deleting it concurrently, where the file will not be - // reported as socket file. + // TODO(gvisor.dev/issue/1200): there are few possible races with someone + // stat'ing the file and another deleting it concurrently, where the file + // will not be reported as socket file. 
endpoints *endpointMaps `state:"wait"` } diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 1cbed07ae..23daeb528 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -21,6 +21,8 @@ go_library( "socket_unsafe.go", "tty.go", "util.go", + "util_amd64_unsafe.go", + "util_arm64_unsafe.go", "util_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go index bad61a9a1..e37e687c6 100644 --- a/pkg/sentry/fs/host/util.go +++ b/pkg/sentry/fs/host/util.go @@ -155,7 +155,7 @@ func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr { AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec), ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec), StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec), - Links: s.Nlink, + Links: uint64(s.Nlink), } } diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go new file mode 100644 index 000000000..66da6e9f5 --- /dev/null +++ b/pkg/sentry/fs/host/util_amd64_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package host + +import ( + "syscall" + "unsafe" +) + +func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { + var stat syscall.Stat_t + namePtr, err := syscall.BytePtrFromString(name) + if err != nil { + return stat, err + } + _, _, errno := syscall.Syscall6( + syscall.SYS_NEWFSTATAT, + uintptr(fd), + uintptr(unsafe.Pointer(namePtr)), + uintptr(unsafe.Pointer(&stat)), + uintptr(flags), + 0, 0) + if errno != 0 { + return stat, errno + } + return stat, nil +} diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go new file mode 100644 index 000000000..e8cb94aeb --- /dev/null +++ b/pkg/sentry/fs/host/util_arm64_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// +build arm64 + +package host + +import ( + "syscall" + "unsafe" +) + +func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { + var stat syscall.Stat_t + namePtr, err := syscall.BytePtrFromString(name) + if err != nil { + return stat, err + } + _, _, errno := syscall.Syscall6( + syscall.SYS_FSTATAT, + uintptr(fd), + uintptr(unsafe.Pointer(namePtr)), + uintptr(unsafe.Pointer(&stat)), + uintptr(flags), + 0, 0) + if errno != 0 { + return stat, errno + } + return stat, nil +} diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go index 2b76f1065..3ab36b088 100644 --- a/pkg/sentry/fs/host/util_unsafe.go +++ b/pkg/sentry/fs/host/util_unsafe.go @@ -116,22 +116,3 @@ func setTimestamps(fd int, ts fs.TimeSpec) error { } return nil } - -func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) { - var stat syscall.Stat_t - namePtr, err := syscall.BytePtrFromString(name) - if err != nil { - return stat, err - } - _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, - uintptr(fd), - uintptr(unsafe.Pointer(namePtr)), - uintptr(unsafe.Pointer(&stat)), - uintptr(flags), - 0, 0) - if errno != 0 { - return stat, errno - } - return stat, nil -} diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index f4ddfa406..91e2fde2f 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -270,6 +270,14 @@ func (i *Inode) Getxattr(name string) (string, error) { return i.InodeOperations.Getxattr(i, name) } +// Setxattr calls i.InodeOperations.Setxattr with i as the Inode. +func (i *Inode) Setxattr(name, value string) error { + if i.overlay != nil { + return overlaySetxattr(i.overlay, name, value) + } + return i.InodeOperations.Setxattr(i, name, value) +} + // Listxattr calls i.InodeOperations.Listxattr with i as the Inode. func (i *Inode) Listxattr() (map[string]struct{}, error) { if i.overlay != nil { @@ -344,6 +352,10 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error // Truncate calls i.InodeOperations.Truncate with i as the Inode. func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error { + if IsDir(i.StableAttr) { + return syserror.EISDIR + } + if i.overlay != nil { return overlayTruncate(ctx, i.overlay, d, size) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 5a388dad1..13d11e001 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -436,7 +436,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena } func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { - if err := copyUp(ctx, parent); err != nil { + if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err } @@ -462,7 +462,9 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri inode.DecRef() return nil, err } - return NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil + // Use the parent's MountSource, since that corresponds to the overlay, + // and not the upper filesystem. + return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil } func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint { @@ -550,6 +552,11 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) { return s, err } +// TODO(b/146028302): Support setxattr for overlayfs. 
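For context on the two fstatat wrappers above: amd64 and arm64 invoke the same kernel entry point under different syscall-number names (SYS_NEWFSTATAT vs. SYS_FSTATAT), which is why the previously shared implementation had to be split per architecture. A standalone sketch of the same call on linux/amd64, statting an entry relative to an open directory FD ("/dev" and "null" are placeholders):

package main

import (
	"fmt"
	"syscall"
	"unsafe"
)

func main() {
	dirfd, err := syscall.Open("/dev", syscall.O_RDONLY|syscall.O_DIRECTORY, 0)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(dirfd)

	var stat syscall.Stat_t
	name, err := syscall.BytePtrFromString("null")
	if err != nil {
		panic(err)
	}
	// SYS_NEWFSTATAT on amd64; the arm64 wrapper issues SYS_FSTATAT instead.
	if _, _, errno := syscall.Syscall6(
		syscall.SYS_NEWFSTATAT,
		uintptr(dirfd),
		uintptr(unsafe.Pointer(name)),
		uintptr(unsafe.Pointer(&stat)),
		0, 0, 0); errno != 0 {
		panic(errno)
	}
	fmt.Printf("mode=%#o rdev=%d\n", stat.Mode, stat.Rdev)
}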
+func overlaySetxattr(o *overlayEntry, name, value string) error { + return syserror.EOPNOTSUPP +} + func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) { o.copyMu.RLock() defer o.copyMu.RUnlock() diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 1d3ff39e0..25573e986 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syncutil" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/third_party/gvsync" ) // The virtual filesystem implements an overlay configuration. For a high-level @@ -199,7 +199,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu gvsync.DowngradableRWMutex `state:"nosave"` + dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 497475db5..75cbb0622 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -52,6 +52,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/tcpip/header", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index f70239449..402919924 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "io" + "reflect" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -33,6 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // newNet creates a new proc net entry. @@ -40,7 +43,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo var contents map[string]*fs.Inode if s := p.k.NetworkStack(); s != nil { contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a @@ -57,7 +61,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo // (ClockGetres returns 1ns resolution). "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")), - "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), @@ -195,6 +199,193 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netSnmp implements seqfile.SeqSource for /proc/net/snmp. +// +// +stateify savable +type netSnmp struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. 
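The overlaySetxattr stub above makes the unsupported case explicit: until the TODO is resolved, writing an extended attribute through an overlay fails with EOPNOTSUPP. From an application's point of view the failure looks like this (the path is a placeholder for a file on an overlay-backed mount):

package main

import (
	"fmt"
	"syscall"
)

func main() {
	err := syscall.Setxattr("/tmp/file", "user.key", []byte("value"), 0)
	if err == syscall.EOPNOTSUPP {
		fmt.Println("setxattr unsupported on this mount:", err)
		return
	}
	fmt.Println("setxattr:", err)
}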
+func (n *netSnmp) NeedsUpdate(generation int64) bool { + return true +} + +type snmpLine struct { + prefix string + header string +} + +var snmp = []snmpLine{ + { + prefix: "Ip", + header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates", + }, + { + prefix: "Icmp", + header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps", + }, + { + prefix: "IcmpMsg", + }, + { + prefix: "Tcp", + header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors", + }, + { + prefix: "Udp", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, + { + prefix: "UdpLite", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, +} + +func toSlice(a interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(a)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + +func sprintSlice(s []uint64) string { + if len(s) == 0 { + return "" + } + r := fmt.Sprint(s) + return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's +// net/core/net-procfs.c:dev_seq_show. +func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + contents := make([]string, 0, len(snmp)*2) + types := []interface{}{ + &inet.StatSNMPIP{}, + &inet.StatSNMPICMP{}, + nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats. + &inet.StatSNMPTCP{}, + &inet.StatSNMPUDP{}, + &inet.StatSNMPUDPLite{}, + } + for i, stat := range types { + line := snmp[i] + if stat == nil { + contents = append( + contents, + fmt.Sprintf("%s:\n", line.prefix), + fmt.Sprintf("%s:\n", line.prefix), + ) + continue + } + if err := n.s.Statistics(stat, line.prefix); err != nil { + if err == syserror.EOPNOTSUPP { + log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } else { + log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } + } + var values string + if line.prefix == "Tcp" { + tcp := stat.(*inet.StatSNMPTCP) + // "Tcp" needs special processing because MaxConn is signed. RFC 2012. + values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + } else { + values = sprintSlice(toSlice(stat)) + } + contents = append( + contents, + fmt.Sprintf("%s: %s\n", line.prefix, line.header), + fmt.Sprintf("%s: %s\n", line.prefix, values), + ) + } + + data := make([]seqfile.SeqData, 0, len(snmp)*2) + for _, l := range contents { + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)}) + } + + return data, 0 +} + +// netRoute implements seqfile.SeqSource for /proc/net/route. +// +// +stateify savable +type netRoute struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netRoute) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. 
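The toSlice and sprintSlice helpers above rely on the inet.StatSNMP* types being fixed-size uint64 arrays: reflect can re-slice an addressable array of any length, and trimming fmt's brackets yields the space-separated counter row that /proc/net/snmp expects. A standalone illustration with a stand-in array type:

package main

import (
	"fmt"
	"reflect"
)

// statTCP stands in for array types like inet.StatSNMPTCP.
type statTCP [4]uint64

func toSlice(a interface{}) []uint64 {
	v := reflect.Indirect(reflect.ValueOf(a))
	return v.Slice(0, v.Len()).Interface().([]uint64)
}

func main() {
	s := &statTCP{1, 200, 120000, 4}
	row := fmt.Sprint(toSlice(s))
	fmt.Println(row[1 : len(row)-1]) // "1 200 120000 4", brackets stripped.
}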
+// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. +func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + interfaces := n.s.Interfaces() + contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"} + for _, rt := range n.s.RouteTable() { + // /proc/net/route only includes ipv4 routes. + if rt.Family != linux.AF_INET { + continue + } + + // /proc/net/route does not include broadcast or multicast routes. + if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST { + continue + } + + iface, ok := interfaces[rt.OutputInterface] + if !ok || iface.Name == "lo" { + continue + } + + var ( + gw uint32 + prefix uint32 + flags = linux.RTF_UP + ) + if len(rt.GatewayAddr) == header.IPv4AddressSize { + flags |= linux.RTF_GATEWAY + gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + } + if len(rt.DstAddr) == header.IPv4AddressSize { + prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + } + l := fmt.Sprintf( + "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", + iface.Name, + prefix, + gw, + flags, + 0, // RefCnt. + 0, // Use. + 0, // Metric. + (uint32(1)<<rt.DstLen)-1, + 0, // MTU. + 0, // Window. + 0, // RTT. + ) + contents = append(contents, l) + } + + var data []seqfile.SeqData + for _, l := range contents { + l = fmt.Sprintf("%-127s\n", l) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)}) + } + + return data, 0 +} + // netUnix implements seqfile.SeqSource for /proc/net/unix. // // +stateify savable diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index f3b63dfc2..bd93f83fa 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -64,7 +64,7 @@ var _ fs.InodeOperations = (*tcpMemInode)(nil) func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode { tm := &tcpMemInode{ - SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), s: s, dir: dir, } @@ -77,6 +77,11 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir return fs.NewInode(ctx, tm, msrc, sattr) } +// Truncate implements fs.InodeOperations.Truncate. +func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true @@ -168,14 +173,15 @@ func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error { // +stateify savable type tcpSack struct { + fsutil.SimpleFileInode + stack inet.Stack `state:"wait"` enabled *bool - fsutil.SimpleFileInode } func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { ts := &tcpSack{ - SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), stack: s, } sattr := fs.StableAttr{ @@ -187,6 +193,11 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f return fs.NewInode(ctx, ts, msrc, sattr) } +// Truncate implements fs.InodeOperations.Truncate. 
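One subtlety in the netRoute formatting above: /proc/net/route renders each IPv4 address as a native-endian uint32 in hex (hence usermem.ByteOrder), so on little-endian targets a gateway of 192.168.0.1 prints as 0100A8C0. A two-line demonstration:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	gw := []byte{192, 168, 0, 1} // 192.168.0.1 in network byte order.
	// binary.LittleEndian matches usermem.ByteOrder on amd64 and arm64.
	fmt.Printf("%08X\n", binary.LittleEndian.Uint32(gw)) // 0100A8C0
}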
+func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 87184ec67..9bf4b4527 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -67,29 +67,28 @@ type taskDir struct { var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. -func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode { +func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - // FIXME(b/123511468): create the correct io file for threads. - "io": newIO(t, msrc), + "auxv": newAuxvec(t, msrc), + "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + "comm": newComm(t, msrc), + "environ": newExecArgInode(t, msrc, environExecArg), + "exe": newExe(t, msrc), + "fd": newFdDir(t, msrc), + "fdinfo": newFdInfoDir(t, msrc), + "gid_map": newGIDMap(t, msrc), + "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "ns": newNamespaceDir(t, msrc), "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, showSubtasks, p.pidns), + "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), "statm": newStatm(t, msrc), "status": newStatus(t, msrc, p.pidns), "uid_map": newUIDMap(t, msrc), } - if showSubtasks { + if isThreadGroup { contents["task"] = p.newSubtasks(t, msrc) } if len(p.cgroupControllers) > 0 { @@ -605,6 +604,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(&buf, "Mems_allowed:\t1\n") + fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n") return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0 } @@ -619,8 +622,11 @@ type ioData struct { ioUsage } -func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) +func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { + if isThreadGroup { + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) + } + return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. 
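The 0444-to-0644 mode changes and the no-op Truncate implementations above work together: shell idioms such as `echo 1 > /proc/sys/net/ipv4/tcp_sack` open the file with O_WRONLY|O_TRUNC, which the old 0444 mode rejected for unprivileged users and which, even once writable, still needs a Truncate implementation to succeed. The Go equivalent of that shell one-liner:

package main

import (
	"fmt"
	"os"
)

func main() {
	// Equivalent of `echo 1 > /proc/sys/net/ipv4/tcp_sack`.
	f, err := os.OpenFile("/proc/sys/net/ipv4/tcp_sack", os.O_WRONLY|os.O_TRUNC, 0)
	if err != nil {
		fmt.Println("open:", err)
		return
	}
	defer f.Close()
	if _, err := f.WriteString("1\n"); err != nil {
		fmt.Println("write:", err)
	}
}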
@@ -639,7 +645,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se io.Accumulate(i.IOUsage()) var buf bytes.Buffer - fmt.Fprintf(&buf, "char: %d\n", io.CharsRead) + fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead) fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten) fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls) fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 19b7557d5..6b07f6bf2 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -76,6 +76,11 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn func (mi *masterInodeOperations) Release(ctx context.Context) { } +// Truncate implements fs.InodeOperations.Truncate. +func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. // // It allocates a new terminal. diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index 944c4ada1..2a51e6bab 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -72,6 +72,11 @@ func (si *slaveInodeOperations) Release(ctx context.Context) { si.t.DecRef() } +// Truncate implements fs.InodeOperations.Truncate. +func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + // GetFile implements fs.InodeOperations.GetFile. // // This may race with destruction of the terminal. If the terminal is gone, it diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index ff8138820..917f90cc0 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -53,8 +53,8 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal d: d, n: n, ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{}, - slaveKTTY: &kernel.TTY{}, + masterKTTY: &kernel.TTY{Index: n}, + slaveKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 7ccff8b0d..880b7bcd3 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -38,6 +38,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/fd", + "//pkg/fspath", "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/context", diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 10a8083a0..177ce2cb9 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -50,7 +50,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. 
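The "char:" to "rchar:" rename above matters because consumers of /proc/[pid]/io match on exact field names (iotop, for example, keys off rchar/wchar), so the misspelled field was effectively invisible. A quick self-check that parses our own io file:

package main

import (
	"fmt"
	"io/ioutil"
	"strings"
)

func main() {
	data, err := ioutil.ReadFile("/proc/self/io")
	if err != nil {
		fmt.Println(err)
		return
	}
	for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
		// Fields look like "rchar: 12345"; the exact prefix is the contract.
		if strings.HasPrefix(line, "rchar:") || strings.HasPrefix(line, "wchar:") {
			fmt.Println(line)
		}
	}
}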
vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -81,7 +81,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf ctx := contexttest.Context(b) creds := auth.CredentialsFromContext(ctx) - if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil { + if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } return func() { diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index cea89bcd9..a2d8c3ad6 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -154,7 +154,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds toRead = len(dst) } - n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) + n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff)) if n < toRead { return n, syserror.EIO } @@ -174,7 +174,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { var childPhyBlk uint32 - err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) + err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 213aa3919..181727ef7 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -87,12 +87,14 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) regFile := regularFile{ inode: inode{ + fs: &filesystem{ + dev: bytes.NewReader(mockDisk), + }, diskInode: &disklayout.InodeNew{ InodeOld: disklayout.InodeOld{ SizeLo: getMockBMFileFize(), }, }, - dev: bytes.NewReader(mockDisk), blkSize: uint64(mockBMBlkSize), }, } diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 054fb42b6..a080cb189 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -41,16 +41,18 @@ func newDentry(in *inode) *dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef() { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef() bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { - d.inode.decRef(vfsfs.Impl().(*filesystem)) +func (d *dentry) DecRef() { + // FIXME(b/134676337): filesystem.mu may not be locked as required by + // inode.decRef(). 
+ d.inode.decRef() } diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index f10accafc..4b7d17dc6 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -40,14 +40,14 @@ var _ vfs.FilesystemType = (*FilesystemType)(nil) // Currently there are two ways of mounting an ext(2/3/4) fs: // 1. Specify a mount with our internal special MountType in the OCI spec. // 2. Expose the device to the container and mount it from application layer. -func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) { +func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) { if opts.InternalData == nil { // User mount call. // TODO(b/134676337): Open the device specified by `source` and return that. panic("unimplemented") } - // NewFilesystem call originated from within the sentry. + // GetFilesystem call originated from within the sentry. devFd, ok := opts.InternalData.(int) if !ok { return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device") @@ -91,8 +91,8 @@ func isCompatible(sb disklayout.SuperBlock) bool { return true } -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { // TODO(b/134676337): Ensure that the user is mounting readonly. If not, // EACCESS should be returned according to mount(2). Filesystem independent // flags (like readonly) are currently not available in pkg/sentry/vfs. @@ -103,7 +103,7 @@ func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials } fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)} - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) fs.sb, err = readSuperBlock(dev) if err != nil { return nil, nil, err diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 1aa2bd6a4..e9f756732 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -66,7 +66,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -147,55 +147,54 @@ func TestSeek(t *testing.T) { t.Fatalf("vfsfs.OpenAt failed: %v", err) } - if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { + if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil { t.Errorf("expected seek position 0, got %d and error %v", n, err) } - stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{}) + stat, err := fd.Stat(ctx, vfs.StatOptions{}) if err != nil { t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err) } // We should be able to seek beyond the end of file. 
size := int64(stat.Size) - if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { + if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } - if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { + if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err) } // Make sure negative offsets work with SEEK_CUR. - if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } // Make sure SEEK_END works with regular files. - switch fd.Impl().(type) { - case *regularFileFD: + if _, ok := fd.Impl().(*regularFileFD); ok { // Seek back to 0. - if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { + if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", 0, n, err) } // Seek forward beyond EOF. - if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { + if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil { t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err) } // EINVAL should be returned if the resulting offset is negative. - if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { + if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL { t.Errorf("expected error EINVAL but got %v", err) } } @@ -456,7 +455,7 @@ func TestRead(t *testing.T) { want := make([]byte, 1) for { n, err := f.Read(want) - fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) + fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}) if diff := cmp.Diff(got, want); diff != "" { t.Errorf("file data mismatch (-want +got):\n%s", diff) @@ -464,7 +463,7 @@ func TestRead(t *testing.T) { // Make sure there is no more file data left after getting EOF. 
if n == 0 || err == io.EOF { - if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { + if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 { t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image) } @@ -509,27 +508,27 @@ func TestIterDirents(t *testing.T) { } wantDirents := []vfs.Dirent{ - vfs.Dirent{ + { Name: ".", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "..", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "lost+found", Type: linux.DT_DIR, }, - vfs.Dirent{ + { Name: "file.txt", Type: linux.DT_REG, }, - vfs.Dirent{ + { Name: "bigfile.txt", Type: linux.DT_REG, }, - vfs.Dirent{ + { Name: "symlink.txt", Type: linux.DT_LNK, }, @@ -574,7 +573,7 @@ func TestIterDirents(t *testing.T) { } cb := &iterDirentsCb{} - if err = fd.Impl().IterDirents(ctx, cb); err != nil { + if err = fd.IterDirents(ctx, cb); err != nil { t.Fatalf("dir fd.IterDirents() failed: %v", err) } diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 38b68a2d3..3d3ebaca6 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -99,7 +99,7 @@ func (f *extentFile) buildExtTree() error { func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) { var header disklayout.ExtentHeader off := entry.PhysicalBlock() * f.regFile.inode.blkSize - err := readFromDisk(f.regFile.inode.dev, int64(off), &header) + err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header) if err != nil { return nil, err } @@ -115,7 +115,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla curEntry = &disklayout.ExtentIdx{} } - err := readFromDisk(f.regFile.inode.dev, int64(off), curEntry) + err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry) if err != nil { return nil, err } @@ -229,7 +229,7 @@ func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byt toRead = len(dst) } - n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], int64(readStart)) + n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart)) if n < toRead { return n, syserror.EIO } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index 42d0a484b..a2382daa3 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -180,13 +180,15 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] mockExtentFile := &extentFile{ regFile: regularFile{ inode: inode{ + fs: &filesystem{ + dev: bytes.NewReader(mockDisk), + }, diskInode: &disklayout.InodeNew{ InodeOld: disklayout.InodeOld{ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root), }, }, blkSize: mockExtentBlkSize, - dev: bytes.NewReader(mockDisk), }, }, } diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index 4d18b28cb..841274daf 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -26,33 +26,14 @@ import ( type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - - // flags is the same as vfs.OpenOptions.Flags which are passed to - // vfs.FilesystemImpl.OpenAt. - // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2), - // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set. - // Only close(2), fstat(2), fstatfs(2) should work. 
- flags uint32 } func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) + return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } func (fd *fileDescription) inode() *inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode -} - -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil + return fd.vfsfd.Dentry().Impl().(*dentry).inode } // Stat implements vfs.FileDescriptionImpl.Stat. diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 2d15e8aaf..d7e87979a 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -20,6 +20,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -441,3 +442,46 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return nil, err + } + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + _, _, err := fs.walk(rp, false) + if err != nil { + return "", err + } + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + _, _, err := fs.walk(rp, false) + if err != nil { + return err + } + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index e6c847a71..b2cc826c7 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -16,7 +16,6 @@ package ext import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -42,13 +41,13 @@ type inode struct { // refs is a reference count. refs is accessed using atomic memory operations. refs int64 + // fs is the containing filesystem. + fs *filesystem + // inodeNum is the inode number of this inode on disk. This is used to // identify inodes within the ext filesystem. inodeNum uint32 - // dev represents the underlying device. Same as filesystem.dev. - dev io.ReaderAt - // blkSize is the fs data block size. Same as filesystem.sb.BlockSize(). 
blkSize uint64 @@ -81,10 +80,10 @@ func (in *inode) tryIncRef() bool { // decRef decrements the inode ref count and releases the inode resources if // the ref count hits 0. // -// Precondition: Must have locked fs.mu. -func (in *inode) decRef(fs *filesystem) { +// Precondition: Must have locked filesystem.mu. +func (in *inode) decRef() { if refs := atomic.AddInt64(&in.refs, -1); refs == 0 { - delete(fs.inodeCache, in.inodeNum) + delete(in.fs.inodeCache, in.inodeNum) } else if refs < 0 { panic("ext.inode.decRef() called without holding a reference") } @@ -117,8 +116,8 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) { // Build the inode based on its type. inode := inode{ + fs: fs, inodeNum: inodeNum, - dev: fs.dev, blkSize: blkSize, diskInode: diskInode, } @@ -154,11 +153,13 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v if err := in.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } + mnt := rp.Mount() switch in.impl.(type) { case *regularFile: var fd regularFileFD - fd.flags = flags - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *directory: // Can't open directories writably. This check is not necessary for a read @@ -167,8 +168,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.EISDIR } var fd directoryFD - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) - fd.flags = flags + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *symlink: if flags&linux.O_PATH == 0 { @@ -176,8 +178,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v return nil, syserror.ELOOP } var fd symlinkFD - fd.flags = flags - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD new file mode 100644 index 000000000..52596c090 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -0,0 +1,60 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "slot_list", + out = "slot_list.go", + package = "kernfs", + prefix = "slot", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*slot", + "Linker": "*slot", + }, +) + +go_library( + name = "kernfs", + srcs = [ + "dynamic_bytes_file.go", + "fd_impl_util.go", + "filesystem.go", + "inode_impl_util.go", + "kernfs.go", + "slot_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/fspath", + "//pkg/log", + "//pkg/refs", + "//pkg/sentry/context", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) + +go_test( + name = "kernfs_test", + size = "small", + srcs = ["kernfs_test.go"], + deps = [ + ":kernfs", + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + 
"@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go new file mode 100644 index 000000000..51102ce48 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -0,0 +1,117 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// DynamicBytesFile implements kernfs.Inode and represents a read-only +// file whose contents are backed by a vfs.DynamicBytesSource. +// +// Must be initialized with Init before first use. +type DynamicBytesFile struct { + InodeAttrs + InodeNoopRefCount + InodeNotDirectory + InodeNotSymlink + + data vfs.DynamicBytesSource +} + +// Init intializes a dynamic bytes file. +func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) { + f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444) + f.data = data +} + +// Open implements Inode.Open. +func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &DynamicBytesFD{} + fd.Init(rp.Mount(), vfsd, f.data, flags) + return &fd.vfsfd, nil +} + +// SetStat implements Inode.SetStat. +func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error { + // DynamicBytesFiles are immutable. + return syserror.EPERM +} + +// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a +// DynamicBytesFile. +// +// Must be initialized with Init before first use. +type DynamicBytesFD struct { + vfs.FileDescriptionDefaultImpl + vfs.DynamicBytesFileDescriptionImpl + + vfsfd vfs.FileDescription + inode Inode +} + +// Init initializes a DynamicBytesFD. +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) { + m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + fd.inode = d.Impl().(*Dentry).inode + fd.SetDataSource(data) + fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}) +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence) +} + +// Read implmenets vfs.FileDescriptionImpl.Read. +func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts) +} + +// PRead implmenets vfs.FileDescriptionImpl.PRead. 
+func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return fd.FileDescriptionDefaultImpl.Write(ctx, src, opts) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.FileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *DynamicBytesFD) Release() {} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(fs), nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { + // DynamicBytesFiles are immutable. + return syserror.EPERM +} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go new file mode 100644 index 000000000..bd402330f --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -0,0 +1,193 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory +// inode that uses OrderedChildren to track child nodes. GenericDirectoryFD is not +// compatible with dynamic directories. +// +// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling +// the IterDirents callback. The IterDirents callback therefore cannot hash or +// unhash children, or recursively call IterDirents on the same underlying +// inode. +// +// Must be initialized with Init before first use. +type GenericDirectoryFD struct { + vfs.FileDescriptionDefaultImpl + vfs.DirectoryFileDescriptionDefaultImpl + + vfsfd vfs.FileDescription + children *OrderedChildren + off int64 +} + +// Init initializes a GenericDirectoryFD. +func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) { + m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref. + fd.children = children + fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}) +} + +// VFSFileDescription returns a pointer to the vfs.FileDescription representing +// this object.
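To make the DynamicBytesFile/DynamicBytesFD pairing above concrete: all a file has to supply is a generator that renders its current contents into a buffer on every read. The sketch below mirrors the shape of vfs.DynamicBytesSource with a local interface (the real Generate also takes a context.Context, omitted here for brevity; the loadavg type is hypothetical):

package main

import (
	"bytes"
	"fmt"
)

// source mirrors the contract of vfs.DynamicBytesSource: regenerate the
// whole file content on demand.
type source interface {
	Generate(buf *bytes.Buffer) error
}

// loadavg is a hypothetical proc-style file backed by a generator.
type loadavg struct{}

func (loadavg) Generate(buf *bytes.Buffer) error {
	_, err := fmt.Fprintf(buf, "%.2f %.2f %.2f 1/10 123\n", 0.08, 0.03, 0.01)
	return err
}

func main() {
	var buf bytes.Buffer
	var s source = loadavg{}
	if err := s.Generate(&buf); err != nil {
		panic(err)
	}
	fmt.Print(buf.String()) // Every read sees freshly generated bytes.
}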
+func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription { + return &fd.vfsfd +} + +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts) +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts) +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *GenericDirectoryFD) Release() {} + +func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem() +} + +func (fd *GenericDirectoryFD) inode() Inode { + return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode +} + +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds +// o.mu when calling cb. +func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + vfsFS := fd.filesystem() + fs := vfsFS.Impl().(*Filesystem) + vfsd := fd.vfsfd.VirtualDentry().Dentry() + + fs.mu.Lock() + defer fs.mu.Unlock() + + // Handle ".". + if fd.off == 0 { + stat := fd.inode().Stat(vfsFS) + dirent := vfs.Dirent{ + Name: ".", + Type: linux.DT_DIR, + Ino: stat.Ino, + NextOff: 1, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + // Handle "..". + if fd.off == 1 { + parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode + stat := parentInode.Stat(vfsFS) + dirent := vfs.Dirent{ + Name: "..", + Type: linux.FileMode(stat.Mode).DirentType(), + Ino: stat.Ino, + NextOff: 2, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + // Handle static children. + fd.children.mu.RLock() + defer fd.children.mu.RUnlock() + // fd.off accounts for "." and "..", but fd.children does not track + // these. + childIdx := fd.off - 2 + for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { + inode := it.Dentry.Impl().(*Dentry).inode + stat := inode.Stat(vfsFS) + dirent := vfs.Dirent{ + Name: it.Name, + Type: linux.FileMode(stat.Mode).DirentType(), + Ino: stat.Ino, + NextOff: fd.off + 1, + } + if !cb.Handle(dirent) { + return nil + } + fd.off++ + } + + return nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fs := fd.filesystem().Impl().(*Filesystem) + fs.mu.Lock() + defer fs.mu.Unlock() + + switch whence { + case linux.SEEK_SET: + // Use offset as given.
+ case linux.SEEK_CUR: + offset += fd.off + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.filesystem() + inode := fd.inode() + return inode.Stat(fs), nil +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + fs := fd.filesystem() + inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return inode.SetStat(fs, opts) +} diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go new file mode 100644 index 000000000..3cbbe4b20 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -0,0 +1,743 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements vfs.FilesystemImpl for kernfs. + +package kernfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// stepExistingLocked resolves rp.Component() in parent directory vfsd. +// +// stepExistingLocked is loosely analogous to fs/namei.c:walk_component(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// +// Postcondition: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) (*vfs.Dentry, error) { + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, syserror.ENOTDIR + } + // Directory searchable? + if err := d.inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } +afterSymlink: + d.dirMu.Lock() + nextVFSD, err := rp.ResolveComponent(vfsd) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + if nextVFSD != nil { + // Cached dentry exists, revalidate. + next := nextVFSD.Impl().(*Dentry) + if !next.inode.Valid(ctx) { + d.dirMu.Lock() + rp.VirtualFilesystem().ForceDeleteDentry(nextVFSD) + d.dirMu.Unlock() + fs.deferDecRef(nextVFSD) // Reference from Lookup. + nextVFSD = nil + } + } + if nextVFSD == nil { + // Dentry isn't cached; it either doesn't exist or failed + // revalidation. Attempt to resolve it via Lookup. + name := rp.Component() + var err error + nextVFSD, err = d.inode.Lookup(ctx, name) + // Reference on nextVFSD dropped by a corresponding Valid. + if err != nil { + return nil, err + } + d.InsertChild(name, nextVFSD) + } + next := nextVFSD.Impl().(*Dentry) + + // Resolve any symlink at current path component. + if rp.ShouldFollowSymlink() && next.isSymlink() { + // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks".
+ target, err := next.inode.Readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + goto afterSymlink + + } + rp.Advance() + return nextVFSD, nil +} + +// walkExistingLocked resolves rp to an existing file. +// +// walkExistingLocked is loosely analogous to Linux's +// fs/namei.c:path_lookupat(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { + vfsd := rp.Start() + for !rp.Done() { + var err error + vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) + if err != nil { + return nil, nil, err + } + } + d := vfsd.Impl().(*Dentry) + if rp.MustBeDir() && !d.isDir() { + return nil, nil, syserror.ENOTDIR + } + return vfsd, d.inode, nil +} + +// walkParentDirLocked resolves all but the last path component of rp to an +// existing directory. It does not check that the returned directory is +// searchable by the provider of rp. +// +// walkParentDirLocked is loosely analogous to Linux's +// fs/namei.c:path_parentat(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// +// Postconditions: Caller must call fs.processDeferredDecRefs*. +func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { + vfsd := rp.Start() + for !rp.Final() { + var err error + vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd) + if err != nil { + return nil, nil, err + } + } + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, nil, syserror.ENOTDIR + } + return vfsd, d.inode, nil +} + +// checkCreateLocked checks that a file named rp.Component() may be created in +// directory parentVFSD, then returns rp.Component(). +// +// Preconditions: Filesystem.mu must be locked for at least reading. parentInode +// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true. +func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return "", err + } + pc := rp.Component() + if pc == "." || pc == ".." { + return "", syserror.EEXIST + } + childVFSD, err := rp.ResolveChild(parentVFSD, pc) + if err != nil { + return "", err + } + if childVFSD != nil { + return "", syserror.EEXIST + } + if parentVFSD.IsDisowned() { + return "", syserror.ENOENT + } + return pc, nil +} + +// checkDeleteLocked checks that the file represented by vfsd may be deleted. +// +// Preconditions: Filesystem.mu must be locked for at least reading. +func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error { + parentVFSD := vfsd.Parent() + if parentVFSD == nil { + return syserror.EBUSY + } + if parentVFSD.IsDisowned() { + return syserror.ENOENT + } + if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + return nil +} + +// checkRenameLocked checks that a rename operation may be performed on the +// target dentry across the given set of parent directories. The target dentry +// may be nil. +// +// Precondition: isDir(dstInode) == true. 
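The afterSymlink label in stepExistingLocked above implements the classic resolve-then-retry loop; rp.HandleSymlink records each traversal so that pathological chains eventually fail with ELOOP instead of spinning forever. A standalone sketch of that bound, with a map standing in for Readlink and 40 mirroring Linux's MAXSYMLINKS:

package main

import (
	"errors"
	"fmt"
)

// resolve follows name through links until it reaches a non-link, giving
// up after 40 hops the way real path resolution does.
func resolve(links map[string]string, name string) (string, error) {
	for i := 0; i < 40; i++ {
		target, ok := links[name]
		if !ok {
			return name, nil // Not a symlink; resolution is done.
		}
		name = target // Retry resolution at the target.
	}
	return "", errors.New("ELOOP: too many levels of symbolic links")
}

func main() {
	links := map[string]string{"a": "b", "b": "c"}
	fmt.Println(resolve(links, "a")) // c <nil>
}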
+func checkRenameLocked(creds *auth.Credentials, src, dstDir *vfs.Dentry, dstInode Inode) error { + srcDir := src.Parent() + if srcDir == nil { + return syserror.EBUSY + } + if srcDir.IsDisowned() { + return syserror.ENOENT + } + if dstDir.IsDisowned() { + return syserror.ENOENT + } + // Check for creation permissions on dst dir. + if err := dstInode.CheckPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + + return nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *Filesystem) Release() { +} + +// Sync implements vfs.FilesystemImpl.Sync. +func (fs *Filesystem) Sync(ctx context.Context) error { + // All filesystem state is in-memory. + return nil +} + +// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. +func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return nil, err + } + + if opts.CheckSearchable { + d := vfsd.Impl().(*Dentry) + if !d.isDir() { + return nil, syserror.ENOTDIR + } + if err := inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + } + vfsd.IncRef() // Ownership transferred to caller. + return vfsd, nil +} + +// LinkAt implements vfs.FilesystemImpl.LinkAt. +func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if rp.Mount() != vd.Mount() { + return syserror.EXDEV + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + + d := vd.Dentry().Impl().(*Dentry) + if d.isDir() { + return syserror.EPERM + } + + child, err := parentInode.NewLink(ctx, pc, d.inode) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// MkdirAt implements vfs.FilesystemImpl.MkdirAt. +func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + child, err := parentInode.NewDir(ctx, pc, opts) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// MknodAt implements vfs.FilesystemImpl.MknodAt. 
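+//
+// Like MkdirAt above, this walks to the parent directory, performs the
+// create-time checks, and delegates node allocation to the parent inode's
+// NewNode.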
+func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + new, err := parentInode.NewNode(ctx, pc, opts) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, new) + return nil +} + +// OpenAt implements vfs.FilesystemImpl.OpenAt. +func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // Filter out flags that are not supported by kernfs. O_DIRECTORY and + // O_NOFOLLOW have no effect here (they're handled by VFS by setting + // appropriate bits in rp), but are returned by + // FileDescriptionImpl.StatusFlags(). + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + ats := vfs.AccessTypesForOpenFlags(opts.Flags) + + // Do not create new file. + if opts.Flags&linux.O_CREAT == 0 { + fs.mu.RLock() + defer fs.processDeferredDecRefs() + defer fs.mu.RUnlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + if err != nil { + return nil, err + } + if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return inode.Open(rp, vfsd, opts.Flags) + } + + // May create new file. + mustCreate := opts.Flags&linux.O_EXCL != 0 + vfsd := rp.Start() + inode := vfsd.Impl().(*Dentry).inode + fs.mu.Lock() + defer fs.mu.Unlock() + if rp.Done() { + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + if mustCreate { + return nil, syserror.EEXIST + } + if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + return inode.Open(rp, vfsd, opts.Flags) + } +afterTrailingSymlink: + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return nil, err + } + // Check for search permission in the parent directory. + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + // Reject attempts to open directories with O_CREAT. + if rp.MustBeDir() { + return nil, syserror.EISDIR + } + pc := rp.Component() + if pc == "." || pc == ".." { + return nil, syserror.EISDIR + } + // Determine whether or not we need to create a file. + childVFSD, err := rp.ResolveChild(parentVFSD, pc) + if err != nil { + return nil, err + } + if childVFSD == nil { + // Already checked for searchability above; now check for writability. + if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return nil, err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + defer rp.Mount().EndWrite() + // Create and open the child. + child, err := parentInode.NewFile(ctx, pc, opts) + if err != nil { + return nil, err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return child.Impl().(*Dentry).inode.Open(rp, child, opts.Flags) + } + // Open existing file or follow symlink. 
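+	// A trailing symlink, when followed, restarts the walk at
+	// afterTrailingSymlink above once the target has been spliced into rp;
+	// this loosely mirrors the trailing-symlink handling in Linux's
+	// fs/namei.c.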
+	if mustCreate {
+		return nil, syserror.EEXIST
+	}
+	childDentry := childVFSD.Impl().(*Dentry)
+	childInode := childDentry.inode
+	if rp.ShouldFollowSymlink() {
+		if childDentry.isSymlink() {
+			target, err := childInode.Readlink(ctx)
+			if err != nil {
+				return nil, err
+			}
+			if err := rp.HandleSymlink(target); err != nil {
+				return nil, err
+			}
+			// rp.Final() may no longer be true since we now need to resolve the
+			// symlink target.
+			goto afterTrailingSymlink
+		}
+	}
+	if err := childInode.CheckPermissions(rp.Credentials(), ats); err != nil {
+		return nil, err
+	}
+	return childInode.Open(rp, childVFSD, opts.Flags)
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+	fs.mu.RLock()
+	d, inode, err := fs.walkExistingLocked(ctx, rp)
+	fs.mu.RUnlock()
+	fs.processDeferredDecRefs()
+	if err != nil {
+		return "", err
+	}
+	if !d.Impl().(*Dentry).isSymlink() {
+		return "", syserror.EINVAL
+	}
+	return inode.Readlink(ctx)
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+	noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
+	exchange := opts.Flags&linux.RENAME_EXCHANGE != 0
+	whiteout := opts.Flags&linux.RENAME_WHITEOUT != 0
+	if exchange && (noReplace || whiteout) {
+		// Can't specify RENAME_NOREPLACE or RENAME_WHITEOUT with RENAME_EXCHANGE.
+		return syserror.EINVAL
+	}
+	if exchange || whiteout {
+		// Exchange and Whiteout flags are not supported on kernfs.
+		return syserror.EINVAL
+	}
+
+	fs.mu.Lock()
+	defer fs.mu.Unlock()
+
+	mnt := rp.Mount()
+	if mnt != vd.Mount() {
+		return syserror.EXDEV
+	}
+
+	if err := mnt.CheckBeginWrite(); err != nil {
+		return err
+	}
+	defer mnt.EndWrite()
+
+	dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+	fs.processDeferredDecRefsLocked()
+	if err != nil {
+		return err
+	}
+
+	srcVFSD := vd.Dentry()
+	srcDirVFSD := srcVFSD.Parent()
+
+	// Can we remove the src dentry?
+	if err := checkDeleteLocked(rp, srcVFSD); err != nil {
+		return err
+	}
+
+	// Can we create the dst dentry?
+	var dstVFSD *vfs.Dentry
+	pc, err := checkCreateLocked(rp, dstDirVFSD, dstDirInode)
+	switch err {
+	case nil:
+		// Ok, continue with rename as replacement.
+	case syserror.EEXIST:
+		if noReplace {
+			// Won't overwrite existing node since RENAME_NOREPLACE was requested.
+			return syserror.EEXIST
+		}
+		dstVFSD, err = rp.ResolveChild(dstDirVFSD, pc)
+		if err != nil {
+			panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD))
+		}
+	default:
+		return err
+	}
+
+	mntns := vfs.MountNamespaceFromContext(ctx)
+	virtfs := rp.VirtualFilesystem()
+
+	srcDirDentry := srcDirVFSD.Impl().(*Dentry)
+	dstDirDentry := dstDirVFSD.Impl().(*Dentry)
+
+	// We can't deadlock here due to lock ordering because we're protected from
+	// concurrent renames by fs.mu held for writing.
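+	// This matches the lock ordering documented in kernfs.go: fs.mu is
+	// always acquired before any Dentry.dirMu.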
+ srcDirDentry.dirMu.Lock() + defer srcDirDentry.dirMu.Unlock() + dstDirDentry.dirMu.Lock() + defer dstDirDentry.dirMu.Unlock() + + if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil { + return err + } + srcDirInode := srcDirDentry.inode + replaced, err := srcDirInode.Rename(ctx, srcVFSD.Name(), pc, srcVFSD, dstDirVFSD) + if err != nil { + virtfs.AbortRenameDentry(srcVFSD, dstVFSD) + return err + } + virtfs.CommitRenameReplaceDentry(srcVFSD, dstDirVFSD, pc, replaced) + return nil +} + +// RmdirAt implements vfs.FilesystemImpl.RmdirAt. +func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + if err := checkDeleteLocked(rp, vfsd); err != nil { + return err + } + if !vfsd.Impl().(*Dentry).isDir() { + return syserror.ENOTDIR + } + if inode.HasChildren() { + return syserror.ENOTEMPTY + } + virtfs := rp.VirtualFilesystem() + parentDentry := vfsd.Parent().Impl().(*Dentry) + parentDentry.dirMu.Lock() + defer parentDentry.dirMu.Unlock() + if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { + return err + } + if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { + virtfs.AbortDeleteDentry(vfsd) + return err + } + virtfs.CommitDeleteDentry(vfsd) + return nil +} + +// SetStatAt implements vfs.FilesystemImpl.SetStatAt. +func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { + fs.mu.RLock() + _, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + if opts.Stat.Mask == 0 { + return nil + } + return inode.SetStat(fs.VFSFilesystem(), opts) +} + +// StatAt implements vfs.FilesystemImpl.StatAt. +func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { + fs.mu.RLock() + _, inode, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return linux.Statx{}, err + } + return inode.Stat(fs.VFSFilesystem()), nil +} + +// StatFSAt implements vfs.FilesystemImpl.StatFSAt. +func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return linux.Statfs{}, err + } + // TODO: actually implement statfs + return linux.Statfs{}, syserror.ENOSYS +} + +// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. +func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { + if rp.Done() { + return syserror.EEXIST + } + fs.mu.Lock() + defer fs.mu.Unlock() + parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + child, err := parentInode.NewSymlink(ctx, pc, target) + if err != nil { + return err + } + parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + return nil +} + +// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. 
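+//
+// Unlike RmdirAt above, UnlinkAt rejects directories with EISDIR and
+// delegates removal to the parent inode's Unlink.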
+func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { + fs.mu.Lock() + defer fs.mu.Unlock() + vfsd, _, err := fs.walkExistingLocked(ctx, rp) + fs.processDeferredDecRefsLocked() + if err != nil { + return err + } + if err := rp.Mount().CheckBeginWrite(); err != nil { + return err + } + defer rp.Mount().EndWrite() + if err := checkDeleteLocked(rp, vfsd); err != nil { + return err + } + if vfsd.Impl().(*Dentry).isDir() { + return syserror.EISDIR + } + virtfs := rp.VirtualFilesystem() + parentDentry := vfsd.Parent().Impl().(*Dentry) + parentDentry.dirMu.Lock() + defer parentDentry.dirMu.Unlock() + if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil { + return err + } + if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { + virtfs.AbortDeleteDentry(vfsd) + return err + } + virtfs.CommitDeleteDentry(vfsd) + return nil +} + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return nil, err + } + // kernfs currently does not support extended attributes. + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return "", err + } + // kernfs currently does not support extended attributes. + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + _, _, err := fs.walkExistingLocked(ctx, rp) + fs.mu.RUnlock() + fs.processDeferredDecRefs() + if err != nil { + return err + } + // kernfs currently does not support extended attributes. + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. +func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go new file mode 100644 index 000000000..7b45b702a --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -0,0 +1,492 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/refs"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// InodeNoopRefCount partially implements the Inode interface, specifically the
+// inodeRefs sub interface. InodeNoopRefCount implements a simple reference
+// count for inodes, performing no extra actions when references are obtained or
+// released. This is suitable for simple file inodes that don't reference any
+// resources.
+type InodeNoopRefCount struct {
+}
+
+// IncRef implements Inode.IncRef.
+func (n *InodeNoopRefCount) IncRef() {
+}
+
+// DecRef implements Inode.DecRef.
+func (n *InodeNoopRefCount) DecRef() {
+}
+
+// TryIncRef implements Inode.TryIncRef.
+func (n *InodeNoopRefCount) TryIncRef() bool {
+	return true
+}
+
+// Destroy implements Inode.Destroy.
+func (n *InodeNoopRefCount) Destroy() {
+}
+
+// InodeDirectoryNoNewChildren partially implements the Inode interface.
+// InodeDirectoryNoNewChildren represents a directory inode which does not
+// support creation of new children.
+type InodeDirectoryNoNewChildren struct{}
+
+// NewFile implements Inode.NewFile.
+func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+	return nil, syserror.EPERM
+}
+
+// NewDir implements Inode.NewDir.
+func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+	return nil, syserror.EPERM
+}
+
+// NewLink implements Inode.NewLink.
+func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+	return nil, syserror.EPERM
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+	return nil, syserror.EPERM
+}
+
+// NewNode implements Inode.NewNode.
+func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+	return nil, syserror.EPERM
+}
+
+// InodeNotDirectory partially implements the Inode interface, specifically the
+// inodeDirectory and inodeDynamicLookup sub interfaces. Inodes that do not
+// represent directories can embed this to provide no-op implementations for
+// directory-related functions.
+type InodeNotDirectory struct {
+}
+
+// HasChildren implements Inode.HasChildren.
+func (*InodeNotDirectory) HasChildren() bool {
+	return false
+}
+
+// NewFile implements Inode.NewFile.
+func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+	panic("NewFile called on non-directory inode")
+}
+
+// NewDir implements Inode.NewDir.
+func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+	panic("NewDir called on non-directory inode")
+}
+
+// NewLink implements Inode.NewLink.
+func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+	panic("NewLink called on non-directory inode")
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+	panic("NewSymlink called on non-directory inode")
+}
+
+// NewNode implements Inode.NewNode.
+func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+	panic("NewNode called on non-directory inode")
+}
+
+// Unlink implements Inode.Unlink.
+func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+	panic("Unlink called on non-directory inode")
+}
+
+// RmDir implements Inode.RmDir.
+func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+	panic("RmDir called on non-directory inode")
+}
+
+// Rename implements Inode.Rename.
+func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+	panic("Rename called on non-directory inode")
+}
+
+// Lookup implements Inode.Lookup.
+func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+	panic("Lookup called on non-directory inode")
+}
+
+// Valid implements Inode.Valid.
+func (*InodeNotDirectory) Valid(context.Context) bool {
+	return true
+}
+
+// InodeNoDynamicLookup partially implements the Inode interface, specifically
+// the inodeDynamicLookup sub interface. Directory inodes that do not support
+// dynamic entries (i.e. entries that are not "hashed" into the
+// vfs.Dentry.children) can embed this to provide no-op implementations for
+// functions related to dynamic entries.
+type InodeNoDynamicLookup struct{}
+
+// Lookup implements Inode.Lookup.
+func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+	return nil, syserror.ENOENT
+}
+
+// Valid implements Inode.Valid.
+func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
+	return true
+}
+
+// InodeNotSymlink partially implements the Inode interface, specifically the
+// inodeSymlink sub interface. All inodes that are not symlinks may embed this
+// to return the appropriate errors from symlink-related functions.
+type InodeNotSymlink struct{}
+
+// Readlink implements Inode.Readlink.
+func (*InodeNotSymlink) Readlink(context.Context) (string, error) {
+	return "", syserror.EINVAL
+}
+
+// InodeAttrs partially implements the Inode interface, specifically the
+// inodeMetadata sub interface. InodeAttrs provides functionality related to
+// inode attributes.
+//
+// Must be initialized by Init prior to first use.
+type InodeAttrs struct {
+	ino   uint64
+	mode  uint32
+	uid   uint32
+	gid   uint32
+	nlink uint32
+}
+
+// Init initializes this InodeAttrs.
+func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMode) {
+	if mode.FileType() == 0 {
+		panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
+	}
+
+	nlink := uint32(1)
+	if mode.FileType() == linux.ModeDirectory {
+		nlink = 2
+	}
+	atomic.StoreUint64(&a.ino, ino)
+	atomic.StoreUint32(&a.mode, uint32(mode))
+	atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
+	atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
+	atomic.StoreUint32(&a.nlink, nlink)
+}
+
+// Mode implements Inode.Mode.
+func (a *InodeAttrs) Mode() linux.FileMode {
+	return linux.FileMode(atomic.LoadUint32(&a.mode))
+}
+
+// Stat partially implements Inode.Stat. Note that this function doesn't provide
+// all the stat fields, and the embedder should consider extending the result
+// with filesystem-specific fields.
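+//
+// As an illustrative sketch (fileInode and its data field are hypothetical,
+// not part of this change), an embedder might extend the result like so:
+//
+//   func (i *fileInode) Stat(fs *vfs.Filesystem) linux.Statx {
+//   	stat := i.InodeAttrs.Stat(fs)
+//   	stat.Mask |= linux.STATX_SIZE
+//   	stat.Size = uint64(len(i.data))
+//   	return stat
+//   }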
+func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx {
+	var stat linux.Statx
+	stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+	stat.Ino = atomic.LoadUint64(&a.ino)
+	stat.Mode = uint16(a.Mode())
+	stat.UID = atomic.LoadUint32(&a.uid)
+	stat.GID = atomic.LoadUint32(&a.gid)
+	stat.Nlink = atomic.LoadUint32(&a.nlink)
+
+	// TODO: Implement other stat fields like timestamps.
+
+	return stat
+}
+
+// SetStat implements Inode.SetStat.
+func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
+	stat := opts.Stat
+	if stat.Mask&linux.STATX_MODE != 0 {
+		for {
+			old := atomic.LoadUint32(&a.mode)
+			// Replace the permission bits while preserving the file type,
+			// which is immutable after node creation.
+			new := (old & uint32(linux.S_IFMT)) | uint32(stat.Mode&^uint16(linux.S_IFMT))
+			if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped {
+				break
+			}
+		}
+	}
+
+	if stat.Mask&linux.STATX_UID != 0 {
+		atomic.StoreUint32(&a.uid, stat.UID)
+	}
+	if stat.Mask&linux.STATX_GID != 0 {
+		atomic.StoreUint32(&a.gid, stat.GID)
+	}
+
+	// Note that not all fields are modifiable. For example, the file type and
+	// inode numbers are immutable after node creation.
+
+	// TODO: Implement other stat fields like timestamps.
+
+	return nil
+}
+
+// CheckPermissions implements Inode.CheckPermissions.
+func (a *InodeAttrs) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+	mode := a.Mode()
+	return vfs.GenericCheckPermissions(
+		creds,
+		ats,
+		mode.FileType() == linux.ModeDirectory,
+		uint16(mode),
+		auth.KUID(atomic.LoadUint32(&a.uid)),
+		auth.KGID(atomic.LoadUint32(&a.gid)),
+	)
+}
+
+// IncLinks implements Inode.IncLinks.
+func (a *InodeAttrs) IncLinks(n uint32) {
+	if atomic.AddUint32(&a.nlink, n) <= n {
+		panic("InodeLink.IncLinks called with no existing links")
+	}
+}
+
+// DecLinks implements Inode.DecLinks.
+func (a *InodeAttrs) DecLinks() {
+	if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) {
+		// Negative overflow.
+		panic("Inode.DecLinks called at 0 links")
+	}
+}
+
+type slot struct {
+	Name   string
+	Dentry *vfs.Dentry
+	slotEntry
+}
+
+// OrderedChildrenOptions contains initialization options for OrderedChildren.
+type OrderedChildrenOptions struct {
+	// Writable indicates whether vfs.FilesystemImpl methods implemented by
+	// OrderedChildren may modify the tracked children. This applies to
+	// operations related to rename, unlink and rmdir. If an OrderedChildren is
+	// not writable, these operations all fail with EPERM.
+	Writable bool
+}
+
+// OrderedChildren partially implements the Inode interface. OrderedChildren can
+// be embedded in directory inodes to keep track of the children in the
+// directory, and can then be used to implement a generic directory FD -- see
+// GenericDirectoryFD. OrderedChildren is not compatible with dynamic
+// directories.
+//
+// Must be initialized with Init before first use.
+type OrderedChildren struct {
+	refs.AtomicRefCount
+
+	// Can children be modified by user syscalls? If set to false, interface
+	// methods that would modify the children return EPERM. Immutable.
+	writable bool
+
+	mu    sync.RWMutex
+	order slotList
+	set   map[string]*slot
+}
+
+// Init initializes an OrderedChildren.
+func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
+	o.writable = opts.Writable
+	o.set = make(map[string]*slot)
+}
+
+// DecRef implements Inode.DecRef.
+func (o *OrderedChildren) DecRef() {
+	o.AtomicRefCount.DecRefWithDestructor(o.Destroy)
+}
+
+// Destroy cleans up resources referenced by this OrderedChildren.
+func (o *OrderedChildren) Destroy() { + o.mu.Lock() + defer o.mu.Unlock() + o.order.Reset() + o.set = nil +} + +// Populate inserts children into this OrderedChildren, and d's dentry +// cache. Populate returns the number of directories inserted, which the caller +// may use to update the link count for the parent directory. +// +// Precondition: d.Impl() must be a kernfs Dentry. d must represent a directory +// inode. children must not contain any conflicting entries already in o. +func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 { + var links uint32 + for name, child := range children { + if child.isDir() { + links++ + } + if err := o.Insert(name, child.VFSDentry()); err != nil { + panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d)) + } + d.InsertChild(name, child.VFSDentry()) + } + return links +} + +// HasChildren implements Inode.HasChildren. +func (o *OrderedChildren) HasChildren() bool { + o.mu.RLock() + defer o.mu.RUnlock() + return len(o.set) > 0 +} + +// Insert inserts child into o. This ignores the writability of o, as this is +// not part of the vfs.FilesystemImpl interface, and is a lower-level operation. +func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error { + o.mu.Lock() + defer o.mu.Unlock() + if _, ok := o.set[name]; ok { + return syserror.EEXIST + } + s := &slot{ + Name: name, + Dentry: child, + } + o.order.PushBack(s) + o.set[name] = s + return nil +} + +// Precondition: caller must hold o.mu for writing. +func (o *OrderedChildren) removeLocked(name string) { + if s, ok := o.set[name]; ok { + delete(o.set, name) + o.order.Remove(s) + } +} + +// Precondition: caller must hold o.mu for writing. +func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry { + if s, ok := o.set[name]; ok { + // Existing slot with given name, simply replace the dentry. + var old *vfs.Dentry + old, s.Dentry = s.Dentry, new + return old + } + + // No existing slot with given name, create and hash new slot. + s := &slot{ + Name: name, + Dentry: new, + } + o.order.PushBack(s) + o.set[name] = s + return nil +} + +// Precondition: caller must hold o.mu for reading or writing. +func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error { + s, ok := o.set[name] + if !ok { + return syserror.ENOENT + } + if s.Dentry != child { + panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child)) + } + return nil +} + +// Unlink implements Inode.Unlink. +func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { + if !o.writable { + return syserror.EPERM + } + o.mu.Lock() + defer o.mu.Unlock() + if err := o.checkExistingLocked(name, child); err != nil { + return err + } + o.removeLocked(name) + return nil +} + +// Rmdir implements Inode.Rmdir. +func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { + // We're not responsible for checking that child is a directory, that it's + // empty, or updating any link counts; so this is the same as unlink. + return o.Unlink(ctx, name, child) +} + +type renameAcrossDifferentImplementationsError struct{} + +func (renameAcrossDifferentImplementationsError) Error() string { + return "rename across inodes with different implementations" +} + +// Rename implements Inode.Rename. +// +// Precondition: Rename may only be called across two directory inodes with +// identical implementations of Rename. 
Practically, this means
+// filesystems that implement Rename by embedding OrderedChildren for any
+// directory implementation must use OrderedChildren for all directory
+// implementations that will support Rename.
+//
+// Postcondition: reference on any replaced dentry transferred to caller.
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) {
+	dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren)
+	if !ok {
+		return nil, renameAcrossDifferentImplementationsError{}
+	}
+	if !o.writable || !dst.writable {
+		return nil, syserror.EPERM
+	}
+	// Note: There's a potential deadlock below if concurrent calls to Rename
+	// refer to the same src and dst directories in reverse. We avoid any
+	// ordering issues because the caller is required to serialize concurrent
+	// calls to Rename in accordance with the interface declaration.
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if dst != o {
+		dst.mu.Lock()
+		defer dst.mu.Unlock()
+	}
+	if err := o.checkExistingLocked(oldname, child); err != nil {
+		return nil, err
+	}
+	replaced := dst.replaceChildLocked(newname, child)
+	return replaced, nil
+}
+
+// nthLocked returns an iterator to the nth child tracked by this object. The
+// iterator is valid until the caller releases o.mu. Returns nil if the
+// requested index falls out of bounds.
+//
+// Precondition: Caller must hold o.mu for reading.
+func (o *OrderedChildren) nthLocked(i int64) *slot {
+	for it := o.order.Front(); it != nil && i >= 0; it = it.Next() {
+		if i == 0 {
+			return it
+		}
+		i--
+	}
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
new file mode 100644
index 000000000..bb01c3d01
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -0,0 +1,405 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kernfs provides the tools to implement inode-based filesystems.
+// Kernfs has two main features:
+//
+// 1. The Inode interface, which maps VFS2's path-based filesystem operations to
+//    specific filesystem nodes. Kernfs uses the Inode interface to provide a
+//    blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as
+//    the synchronization mechanism for all filesystem operations by holding a
+//    filesystem-wide lock across all operations.
+//
+// 2. Various utility types which provide generic implementations for various
+//    parts of the Inode and vfs.FileDescription interfaces. Client filesystems
+//    based on kernfs can embed the appropriate set of these to avoid having to
+//    reimplement common filesystem operations. See inode_impl_util.go and
+//    fd_impl_util.go.
+//
+// Reference Model:
+//
+// Kernfs dentries represent named pointers to inodes. Dentries and inodes have
+// independent lifetimes and reference counts. A child dentry unconditionally
+// holds a reference on its parent directory's dentry. A dentry also holds a
+// reference on the inode it points to. Multiple dentries can point to the same
+// inode (for example, in the case of hardlinks). File descriptors hold a
+// reference to the dentry they're opened on.
+//
+// Dentries are guaranteed to exist while holding Filesystem.mu for
+// reading. Dropping dentries requires holding Filesystem.mu for writing. To
+// queue dentries for destruction from a read critical section, see
+// Filesystem.deferDecRef.
+//
+// Lock ordering:
+//
+//   kernfs.Filesystem.mu
+//     kernfs.Dentry.dirMu
+//       vfs.VirtualFilesystem.mountMu
+//         vfs.Dentry.mu
+//           kernfs.Filesystem.droppedDentriesMu
+//           (inode implementation locks, if any)
+package kernfs
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/refs"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
+// filesystem. Concrete implementations are expected to embed this in their own
+// Filesystem type.
+type Filesystem struct {
+	vfsfs vfs.Filesystem
+
+	droppedDentriesMu sync.Mutex
+
+	// droppedDentries is a list of dentries waiting to be DecRef()ed. This is
+	// used to defer dentry destruction until mu can be acquired for
+	// writing. Protected by droppedDentriesMu.
+	droppedDentries []*vfs.Dentry
+
+	// mu synchronizes the lifetime of Dentries on this filesystem. Holding it
+	// for reading guarantees continued existence of any resolved dentries, but
+	// the dentry tree may be modified.
+	//
+	// Kernfs dentries can only be DecRef()ed while holding mu for writing. For
+	// example:
+	//
+	//   fs.mu.Lock()
+	//   defer fs.mu.Unlock()
+	//   ...
+	//   dentry1.DecRef()
+	//   defer dentry2.DecRef() // Ok, will run before Unlock.
+	//
+	// If discarding dentries in a read context, use Filesystem.deferDecRef. For
+	// example:
+	//
+	//   fs.mu.RLock()
+	//   defer fs.processDeferredDecRefs()
+	//   defer fs.mu.RUnlock()
+	//   ...
+	//   fs.deferDecRef(dentry)
+	mu sync.RWMutex
+
+	// nextInoMinusOne is used to allocate inode numbers on this
+	// filesystem. Must be accessed by atomic operations.
+	nextInoMinusOne uint64
+}
+
+// deferDecRef defers dropping a dentry ref until the next call to
+// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
+//
+// Precondition: d must not already be pending destruction.
+func (fs *Filesystem) deferDecRef(d *vfs.Dentry) {
+	fs.droppedDentriesMu.Lock()
+	fs.droppedDentries = append(fs.droppedDentries, d)
+	fs.droppedDentriesMu.Unlock()
+}
+
+// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
+// droppedDentries list. See comment on Filesystem.mu.
+func (fs *Filesystem) processDeferredDecRefs() {
+	fs.mu.Lock()
+	fs.processDeferredDecRefsLocked()
+	fs.mu.Unlock()
+}
+
+// Precondition: fs.mu must be held for writing.
+func (fs *Filesystem) processDeferredDecRefsLocked() {
+	fs.droppedDentriesMu.Lock()
+	for _, d := range fs.droppedDentries {
+		d.DecRef()
+	}
+	fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
+	fs.droppedDentriesMu.Unlock()
+}
+
+// Init initializes a kernfs filesystem. This should be called during
+// vfs.FilesystemType.GetFilesystem for the concrete filesystem embedding
+// kernfs.
+func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) {
+	fs.vfsfs.Init(vfsObj, fs)
+}
+
+// VFSFilesystem returns the generic vfs filesystem object.
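+//
+// Concrete filesystems typically hand this out from their
+// vfs.FilesystemType.GetFilesystem implementation, as the test filesystem
+// later in this change does:
+//
+//   fs := &filesystem{}
+//   fs.Init(vfsObj)
+//   return fs.VFSFilesystem(), root.VFSDentry(), nil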
+func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem {
+	return &fs.vfsfs
+}
+
+// NextIno allocates a new inode number on this filesystem.
+func (fs *Filesystem) NextIno() uint64 {
+	return atomic.AddUint64(&fs.nextInoMinusOne, 1)
+}
+
+// These consts are used in the Dentry.flags field.
+const (
+	// Dentry points to a directory inode.
+	dflagsIsDir = 1 << iota
+
+	// Dentry points to a symlink inode.
+	dflagsIsSymlink
+)
+
+// Dentry implements vfs.DentryImpl.
+//
+// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a
+// named reference to an inode. A dentry generally lives as long as it's part of
+// a mounted filesystem tree. Kernfs doesn't cache dentries once all references
+// to them are removed. Dentries hold a single reference to the inode they point
+// to, and child dentries hold a reference on their parent.
+//
+// Must be initialized by Init prior to first use.
+type Dentry struct {
+	refs.AtomicRefCount
+
+	vfsd  vfs.Dentry
+	inode Inode
+
+	refs uint64
+
+	// flags caches useful information about the dentry from the inode. See the
+	// dflags* consts above. Must be accessed by atomic ops.
+	flags uint32
+
+	// dirMu protects vfsd.children for directory dentries.
+	dirMu sync.Mutex
+}
+
+// Init initializes this dentry.
+//
+// Precondition: Caller must hold a reference on inode.
+//
+// Postcondition: Caller's reference on inode is transferred to the dentry.
+func (d *Dentry) Init(inode Inode) {
+	d.vfsd.Init(d)
+	d.inode = inode
+	ftype := inode.Mode().FileType()
+	if ftype == linux.ModeDirectory {
+		d.flags |= dflagsIsDir
+	}
+	if ftype == linux.ModeSymlink {
+		d.flags |= dflagsIsSymlink
+	}
+}
+
+// VFSDentry returns the generic vfs dentry for this kernfs dentry.
+func (d *Dentry) VFSDentry() *vfs.Dentry {
+	return &d.vfsd
+}
+
+// isDir checks whether the dentry points to a directory inode.
+func (d *Dentry) isDir() bool {
+	return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0
+}
+
+// isSymlink checks whether the dentry points to a symlink inode.
+func (d *Dentry) isSymlink() bool {
+	return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef() {
+	d.AtomicRefCount.DecRefWithDestructor(d.destroy)
+}
+
+// Precondition: Dentry must be removed from VFS' dentry cache.
+func (d *Dentry) destroy() {
+	d.inode.DecRef() // IncRef from Init.
+	d.inode = nil
+	if parent := d.vfsd.Parent(); parent != nil {
+		parent.DecRef() // IncRef from Dentry.InsertChild.
+	}
+}
+
+// InsertChild inserts child into the vfs dentry cache with the given name under
+// this dentry. This does not update the directory inode, so calling this on
+// its own isn't sufficient to insert a child into a directory. InsertChild
+// updates the link count on d if required.
+//
+// Precondition: d must represent a directory inode.
+func (d *Dentry) InsertChild(name string, child *vfs.Dentry) {
+	if !d.isDir() {
+		panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d))
+	}
+	vfsDentry := d.VFSDentry()
+	vfsDentry.IncRef() // DecRef in child's Dentry.destroy.
+	d.dirMu.Lock()
+	vfsDentry.InsertChild(child, name)
+	d.dirMu.Unlock()
+}
+
+// The Inode interface maps filesystem-level operations that operate on paths to
+// equivalent operations on specific filesystem nodes.
+//
+// The interface methods are grouped into logical categories as sub interfaces
+// below. Generally, an implementation for each sub interface can be provided by
+// embedding an appropriate type from inode_impl_util.go. The sub interfaces
+// are purely organizational. Methods declared directly in the main interface
+// have no generic implementations, and should be explicitly provided by the
+// client filesystem.
+//
+// Generally, implementations are not responsible for tasks that are common to
+// all filesystems. These include:
+//
+// - Checking that dentries passed to methods are of the appropriate file type.
+// - Checking permissions.
+// - Updating link and reference counts.
+//
+// Specific responsibilities of implementations are documented below.
+type Inode interface {
+	// Methods related to reference counting. A generic implementation is
+	// provided by InodeNoopRefCount. These methods are generally called by the
+	// equivalent Dentry methods.
+	inodeRefs
+
+	// Methods related to node metadata. A generic implementation is provided by
+	// InodeAttrs.
+	inodeMetadata
+
+	// Methods for inodes that represent symlinks. InodeNotSymlink provides a
+	// blanket implementation for all non-symlink inodes.
+	inodeSymlink
+
+	// Methods for inodes that represent directories. InodeNotDirectory provides
+	// a blanket implementation for all non-directory inodes.
+	inodeDirectory
+
+	// Methods for inodes that represent dynamic directories and their
+	// children. InodeNoDynamicLookup provides a blanket implementation for all
+	// non-dynamic-directory inodes.
+	inodeDynamicLookup
+
+	// Open creates a file description for the filesystem object represented by
+	// this inode. The returned file description should hold a reference on the
+	// inode for its lifetime.
+	//
+	// Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry.
+	Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error)
+}
+
+type inodeRefs interface {
+	IncRef()
+	DecRef()
+	TryIncRef() bool
+	// Destroy is called when the inode reaches zero references. Destroy
+	// releases all resources (references) on objects referenced by the inode,
+	// including any child dentries.
+	Destroy()
+}
+
+type inodeMetadata interface {
+	// CheckPermissions checks that creds may access this inode for the
+	// requested access type, per the rules of
+	// fs/namei.c:generic_permission().
+	CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error
+
+	// Mode returns the (struct stat)::st_mode value for this inode. This is
+	// separated from Stat for performance.
+	Mode() linux.FileMode
+
+	// Stat returns the metadata for this inode. This corresponds to
+	// vfs.FilesystemImpl.StatAt.
+	Stat(fs *vfs.Filesystem) linux.Statx
+
+	// SetStat updates the metadata for this inode. This corresponds to
+	// vfs.FilesystemImpl.SetStatAt.
+	SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error
+}
+
+// Precondition: All methods in this interface may only be called on directory
+// inodes.
+type inodeDirectory interface {
+	// The New{File,Dir,Node,Symlink} methods below should return a new inode
+	// hashed into this inode.
+	//
+	// These inode constructors are inode-level operations rather than
+	// filesystem-level operations to allow client filesystems to mix different
+	// implementations based on the new node's location in the
+	// filesystem.
+
+	// HasChildren returns true if the directory inode has any children.
+	HasChildren() bool
+
+	// NewFile creates a new regular file inode.
+	NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error)
+
+	// NewDir creates a new directory inode.
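+	// A typical implementation allocates the new inode, wraps it via
+	// Dentry.Init, and returns Dentry.VFSDentry(); see the test
+	// filesystem's newDir in kernfs_test.go for a concrete example.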
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) + + // NewLink creates a new hardlink to a specified inode in this + // directory. Implementations should create a new kernfs Dentry pointing to + // target, and update target's link count. + NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) + + // NewSymlink creates a new symbolic link inode. + NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) + + // NewNode creates a new filesystem node for a mknod syscall. + NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) + + // Unlink removes a child dentry from this directory inode. + Unlink(ctx context.Context, name string, child *vfs.Dentry) error + + // RmDir removes an empty child directory from this directory + // inode. Implementations must update the parent directory's link count, + // if required. Implementations are not responsible for checking that child + // is a directory, checking for an empty directory. + RmDir(ctx context.Context, name string, child *vfs.Dentry) error + + // Rename is called on the source directory containing an inode being + // renamed. child should point to the resolved child in the source + // directory. If Rename replaces a dentry in the destination directory, it + // should return the replaced dentry or nil otherwise. + // + // Precondition: Caller must serialize concurrent calls to Rename. + Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error) +} + +type inodeDynamicLookup interface { + // Lookup should return an appropriate dentry if name should resolve to a + // child of this dynamic directory inode. This gives the directory an + // opportunity on every lookup to resolve additional entries that aren't + // hashed into the directory. This is only called when the inode is a + // directory. If the inode is not a directory, or if the directory only + // contains a static set of children, the implementer can unconditionally + // return an appropriate error (ENOTDIR and ENOENT respectively). + // + // The child returned by Lookup will be hashed into the VFS dentry tree. Its + // lifetime can be controlled by the filesystem implementation with an + // appropriate implementation of Valid. + // + // Lookup returns the child with an extra reference and the caller owns this + // reference. + Lookup(ctx context.Context, name string) (*vfs.Dentry, error) + + // Valid should return true if this inode is still valid, or needs to + // be resolved again by a call to Lookup. + Valid(ctx context.Context) bool +} + +type inodeSymlink interface { + // Readlink resolves the target of a symbolic link. If an inode is not a + // symlink, the implementation should return EINVAL. + Readlink(ctx context.Context) (string, error) +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go new file mode 100644 index 000000000..f78bb7b04 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -0,0 +1,423 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs_test + +import ( + "bytes" + "fmt" + "io" + "runtime" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const defaultMode linux.FileMode = 01777 +const staticFileContent = "This is sample content for a static test file." + +// RootDentryFn is a generator function for creating the root dentry of a test +// filesystem. See newTestSystem. +type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry + +// TestSystem represents the context for a single test. +type TestSystem struct { + t *testing.T + ctx context.Context + creds *auth.Credentials + vfs *vfs.VirtualFilesystem + mns *vfs.MountNamespace + root vfs.VirtualDentry +} + +// newTestSystem sets up a minimal environment for running a test, including an +// instance of a test filesystem. Tests can control the contents of the +// filesystem by providing an appropriate rootFn, which should return a +// pre-populated root dentry. +func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + v := vfs.New() + v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}) + mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("Failed to create testfs root mount: %v", err) + } + + s := &TestSystem{ + t: t, + ctx: ctx, + creds: creds, + vfs: v, + mns: mns, + root: mns.Root(), + } + runtime.SetFinalizer(s, func(s *TestSystem) { s.root.DecRef() }) + return s +} + +// PathOpAtRoot constructs a vfs.PathOperation for a path from the +// root of the test filesystem. +// +// Precondition: path should be relative path. +func (s *TestSystem) PathOpAtRoot(path string) vfs.PathOperation { + return vfs.PathOperation{ + Root: s.root, + Start: s.root, + Pathname: path, + } +} + +// GetDentryOrDie attempts to resolve a dentry referred to by the +// provided path operation. If unsuccessful, the test fails. 
+func (s *TestSystem) GetDentryOrDie(pop vfs.PathOperation) vfs.VirtualDentry { + vd, err := s.vfs.GetDentryAt(s.ctx, s.creds, &pop, &vfs.GetDentryOptions{}) + if err != nil { + s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err) + } + return vd +} + +func (s *TestSystem) ReadToEnd(fd *vfs.FileDescription) (string, error) { + buf := make([]byte, usermem.PageSize) + bufIOSeq := usermem.BytesIOSequence(buf) + opts := vfs.ReadOptions{} + + var content bytes.Buffer + for { + n, err := fd.Impl().Read(s.ctx, bufIOSeq, opts) + if n == 0 || err != nil { + if err == io.EOF { + err = nil + } + return content.String(), err + } + content.Write(buf[:n]) + } +} + +type fsType struct { + rootFn RootDentryFn +} + +type filesystem struct { + kernfs.Filesystem +} + +type file struct { + kernfs.DynamicBytesFile + content string +} + +func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry { + f := &file{} + f.content = content + f.DynamicBytesFile.Init(creds, fs.NextIno(), f) + + d := &kernfs.Dentry{} + d.Init(f) + return d +} + +func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%s", f.content) + return nil +} + +type attrs struct { + kernfs.InodeAttrs +} + +func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error { + return syserror.EPERM +} + +type readonlyDir struct { + attrs + kernfs.InodeNotSymlink + kernfs.InodeNoDynamicLookup + kernfs.InodeDirectoryNoNewChildren + + kernfs.OrderedChildren + dentry kernfs.Dentry +} + +func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { + dir := &readonlyDir{} + dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + dir.dentry.Init(dir) + + dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) + + return &dir.dentry +} + +func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +type dir struct { + attrs + kernfs.InodeNotSymlink + kernfs.InodeNoDynamicLookup + + fs *filesystem + dentry kernfs.Dentry + kernfs.OrderedChildren +} + +func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry { + dir := &dir{} + dir.fs = fs + dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode) + dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) + dir.dentry.Init(dir) + + dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) + + return &dir.dentry +} + +func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + fd := &kernfs.GenericDirectoryFD{} + fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags) + return fd.VFSFileDescription(), nil +} + +func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { + creds := auth.CredentialsFromContext(ctx) + dir := d.fs.newDir(creds, opts.Mode, nil) + dirVFSD := dir.VFSDentry() + if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { + dir.DecRef() + return nil, err + } + d.IncLinks(1) + return dirVFSD, nil +} + +func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { + creds := auth.CredentialsFromContext(ctx) + f := d.fs.newFile(creds, "") + fVFSD := 
f.VFSDentry() + if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { + f.DecRef() + return nil, err + } + return fVFSD, nil +} + +func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + fs := &filesystem{} + fs.Init(vfsObj) + root := fst.rootFn(creds, fs) + return fs.VFSFilesystem(), root.VFSDentry(), nil +} + +// -------------------- Remainder of the file are test cases -------------------- + +func TestBasic(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "file1": fs.newFile(creds, staticFileContent), + }) + }) + sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef() +} + +func TestMkdirGetDentry(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir1": fs.newDir(creds, 0755, nil), + }) + }) + + pop := sys.PathOpAtRoot("dir1/a new directory") + if err := sys.vfs.MkdirAt(sys.ctx, sys.creds, &pop, &vfs.MkdirOptions{Mode: 0755}); err != nil { + t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err) + } + sys.GetDentryOrDie(pop).DecRef() +} + +func TestReadStaticFile(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "file1": fs.newFile(creds, staticFileContent), + }) + }) + + pop := sys.PathOpAtRoot("file1") + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err) + } + defer fd.DecRef() + + content, err := sys.ReadToEnd(fd) + if err != nil { + sys.t.Fatalf("Read failed: %v", err) + } + if diff := cmp.Diff(staticFileContent, content); diff != "" { + sys.t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff) + } +} + +func TestCreateNewFileInStaticDir(t *testing.T) { + sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry { + return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{ + "dir1": fs.newDir(creds, 0755, nil), + }) + }) + + pop := sys.PathOpAtRoot("dir1/newfile") + opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode} + fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, opts) + if err != nil { + sys.t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err) + } + + // Close the file. The file should persist. + fd.DecRef() + + fd, err = sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{}) + if err != nil { + sys.t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err) + } + fd.DecRef() +} + +// direntCollector provides an implementation for vfs.IterDirentsCallback for +// testing. It simply iterates to the end of a given directory FD and collects +// all dirents emitted by the callback. 
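+// Handle returns true after each dirent so iteration always runs to
+// completion; returning false would stop IterDirents early.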
+type direntCollector struct {
+	mu      sync.Mutex
+	dirents map[string]vfs.Dirent
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (d *direntCollector) Handle(dirent vfs.Dirent) bool {
+	d.mu.Lock()
+	if d.dirents == nil {
+		d.dirents = make(map[string]vfs.Dirent)
+	}
+	d.dirents[dirent.Name] = dirent
+	d.mu.Unlock()
+	return true
+}
+
+// count returns the number of dirents currently in the collector.
+func (d *direntCollector) count() int {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	return len(d.dirents)
+}
+
+// contains checks whether the collector has a dirent with the given name and
+// type.
+func (d *direntCollector) contains(name string, typ uint8) error {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	dirent, ok := d.dirents[name]
+	if !ok {
+		return fmt.Errorf("No dirent named %q found", name)
+	}
+	if dirent.Type != typ {
+		return fmt.Errorf("Dirent named %q found, but was expecting type %d, got: %+v", name, typ, dirent)
+	}
+	return nil
+}
+
+func TestDirFDReadWrite(t *testing.T) {
+	sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+		return fs.newReadonlyDir(creds, 0755, nil)
+	})
+
+	pop := sys.PathOpAtRoot("/")
+	fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{})
+	if err != nil {
+		sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+	}
+	defer fd.DecRef()
+
+	// Read/Write should fail for directory FDs.
+	if _, err := fd.Read(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR {
+		sys.t.Fatalf("Read for directory FD failed with unexpected error: %v", err)
+	}
+	if _, err := fd.Write(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EISDIR {
+		sys.t.Fatalf("Write for directory FD failed with unexpected error: %v", err)
+	}
+}
+
+func TestDirFDIterDirents(t *testing.T) {
+	sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+		return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+			// Fill root with nodes backed by various inode implementations.
+			"dir1": fs.newReadonlyDir(creds, 0755, nil),
+			"dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{
+				"dir3": fs.newDir(creds, 0755, nil),
+			}),
+			"file1": fs.newFile(creds, staticFileContent),
+		})
+	})
+
+	pop := sys.PathOpAtRoot("/")
+	fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{})
+	if err != nil {
+		sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+	}
+	defer fd.DecRef()
+
+	collector := &direntCollector{}
+	if err := fd.IterDirents(sys.ctx, collector); err != nil {
+		sys.t.Fatalf("IterDirent failed: %v", err)
+	}
+
+	// Root directory should contain ".", ".."
and 3 children.
+	if collector.count() != 5 {
+		sys.t.Fatalf("IterDirent returned wrong number of dirents: got %d, want 5", collector.count())
+	}
+	for _, dirName := range []string{".", "..", "dir1", "dir2"} {
+		if err := collector.contains(dirName, linux.DT_DIR); err != nil {
+			sys.t.Fatalf("IterDirent had unexpected results: %v", err)
+		}
+	}
+	if err := collector.contains("file1", linux.DT_REG); err != nil {
+		sys.t.Fatalf("IterDirent had unexpected results: %v", err)
+	}
+}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index 7e364c5fd..0cc751eb8 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -24,14 +24,19 @@ go_library(
         "directory.go",
         "filesystem.go",
         "memfs.go",
+        "named_pipe.go",
         "regular_file.go",
         "symlink.go",
     ],
     importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
     deps = [
         "//pkg/abi/linux",
+        "//pkg/amutex",
+        "//pkg/fspath",
+        "//pkg/sentry/arch",
         "//pkg/sentry/context",
         "//pkg/sentry/kernel/auth",
+        "//pkg/sentry/kernel/pipe",
         "//pkg/sentry/usermem",
         "//pkg/sentry/vfs",
         "//pkg/syserror",
@@ -45,6 +50,7 @@ go_test(
     deps = [
         ":memfs",
         "//pkg/abi/linux",
+        "//pkg/refs",
         "//pkg/sentry/context",
         "//pkg/sentry/context/contexttest",
         "//pkg/sentry/fs",
@@ -54,3 +60,19 @@ go_test(
         "//pkg/syserror",
     ],
 )
+
+go_test(
+    name = "memfs_test",
+    size = "small",
+    srcs = ["pipe_test.go"],
+    embed = [":memfs"],
+    deps = [
+        "//pkg/abi/linux",
+        "//pkg/sentry/context",
+        "//pkg/sentry/context/contexttest",
+        "//pkg/sentry/kernel/auth",
+        "//pkg/sentry/usermem",
+        "//pkg/sentry/vfs",
+        "//pkg/syserror",
+    ],
+)
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/memfs/benchmark_test.go
index a94b17db6..4a7a94a52 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/memfs/benchmark_test.go
@@ -21,6 +21,7 @@ import (
 	"testing"
 
 	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/refs"
 	"gvisor.dev/gvisor/pkg/sentry/context"
 	"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
 	"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -160,6 +161,8 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
 				b.Fatalf("stat(%q) failed: %v", filePath, err)
 			}
 		}
+		// Don't include deferred cleanup in benchmark time.
+ b.StopTimer() }) } } @@ -343,6 +350,8 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) { b.Fatalf("stat(%q) failed: %v", filePath, err) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } @@ -356,10 +365,11 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { // Create VFS. vfsObj := vfs.New() vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } + defer mntns.DecRef(vfsObj) var filePathBuilder strings.Builder filePathBuilder.WriteByte('/') @@ -384,7 +394,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { } defer mountPoint.DecRef() // Create and mount the submount. - if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.NewFilesystemOptions{}); err != nil { + if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) @@ -395,7 +405,6 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { if err != nil { b.Fatalf("failed to walk to mount root: %v", err) } - defer vd.DecRef() for i := depth; i > 0; i-- { name := fmt.Sprintf("%d", i) pop := vfs.PathOperation{ @@ -435,6 +444,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, Mode: 0644, }) + vd.DecRef() if err != nil { b.Fatalf("failed to create file %q: %v", filename, err) } @@ -459,6 +469,14 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) { b.Fatalf("got wrong permissions (%0o)", stat.Mode) } } + // Don't include deferred cleanup in benchmark time. + b.StopTimer() }) } } + +func init() { + // Turn off reference leak checking for a fair comparison between vfs1 and + // vfs2. + refs.SetLeakMode(refs.NoLeakChecking) +} diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go index f79e2d9c8..af4389459 100644 --- a/pkg/sentry/fsimpl/memfs/filesystem.go +++ b/pkg/sentry/fsimpl/memfs/filesystem.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -159,7 +160,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op return nil, err } } - inode.incRef() // vfsd.IncRef(&fs.vfsfs) + inode.incRef() return vfsd, nil } @@ -233,7 +234,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if err != nil { return err } - _, err = checkCreateLocked(rp, parentVFSD, parentInode) + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) if err != nil { return err } @@ -241,17 +242,49 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - // TODO: actually implement mknod - return syserror.EPERM + + switch opts.Mode.FileType() { + case 0: + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + fallthrough + case linux.ModeRegular: + // TODO(b/138862511): Implement. 
+ return syserror.EINVAL + + case linux.ModeNamedPipe: + child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode)) + parentVFSD.InsertChild(&child.vfsd, pc) + parentInode.impl.(*directory).childList.PushBack(child) + return nil + + case linux.ModeSocket: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeCharacterDevice: + fallthrough + case linux.ModeBlockDevice: + // TODO(b/72101894): We don't support creating block or character + // devices at the moment. + // + // When we start supporting block and character devices, we'll + // need to check for CAP_MKNOD here. + return syserror.EPERM + + default: + // "EINVAL - mode requested creation of something other than a + // regular file, device special file, FIFO or socket." - mknod(2) + return syserror.EINVAL + } } // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { // Filter out flags that are not supported by memfs. O_DIRECTORY and // O_NOFOLLOW have no effect here (they're handled by VFS by setting - // appropriate bits in rp), but are returned by - // FileDescriptionImpl.StatusFlags(). - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + // appropriate bits in rp), but are visible in FD status flags. O_NONBLOCK + // is supported only by pipes. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() @@ -260,7 +293,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return nil, err } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } mustCreate := opts.Flags&linux.O_EXCL != 0 @@ -275,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } afterTrailingSymlink: // Walk to the parent directory of the last path component. @@ -320,7 +353,7 @@ afterTrailingSymlink: child := fs.newDentry(childInode) vfsd.InsertChild(&child.vfsd, pc) inode.impl.(*directory).childList.PushBack(child) - return childInode.open(rp, &child.vfsd, opts.Flags, true) + return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true) } // Open existing file or follow symlink. if mustCreate { @@ -336,29 +369,31 @@ afterTrailingSymlink: // symlink target. 
goto afterTrailingSymlink } - return childInode.open(rp, childVFSD, opts.Flags, false) + return childInode.open(ctx, rp, childVFSD, opts.Flags, false) } -func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { +func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(flags) if !afterCreate { if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil { return nil, err } } + mnt := rp.Mount() switch impl := i.impl.(type) { case *regularFile: var fd regularFileFD - fd.flags = flags fd.readable = vfs.MayReadFileWithOpenFlags(flags) fd.writable = vfs.MayWriteFileWithOpenFlags(flags) if fd.writable { - if err := rp.Mount().CheckBeginWrite(); err != nil { + if err := mnt.CheckBeginWrite(); err != nil { return nil, err } - // Mount.EndWrite() is called by regularFileFD.Release(). + // mnt.EndWrite() is called by regularFileFD.Release(). } - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) if flags&linux.O_TRUNC != 0 { impl.mu.Lock() impl.data = impl.data[:0] @@ -372,12 +407,15 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte return nil, syserror.EISDIR } var fd directoryFD - fd.vfsfd.Init(&fd, rp.Mount(), vfsd) - fd.flags = flags + mnt.IncRef() + vfsd.IncRef() + fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) return &fd.vfsfd, nil case *symlink: // Can't open symlinks without O_PATH (which is unimplemented). return nil, syserror.ELOOP + case *namedPipe: + return newNamedPipeFD(ctx, impl, rp, vfsd, flags) default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } @@ -542,3 +580,58 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error inode.decLinksLocked() return nil } + +// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. +func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return nil, err + } + // TODO(b/127675828): support extended attributes + return nil, syserror.ENOTSUP +} + +// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. +func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return "", err + } + // TODO(b/127675828): support extended attributes + return "", syserror.ENOTSUP +} + +// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. +func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. +func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + _, _, err := walkExistingLocked(rp) + if err != nil { + return err + } + // TODO(b/127675828): support extended attributes + return syserror.ENOTSUP +} + +// PrependPath implements vfs.FilesystemImpl.PrependPath. 
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error { + fs.mu.RLock() + defer fs.mu.RUnlock() + return vfs.GenericPrependPath(vfsroot, vd, b) +} diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index b78471c0f..9d509f6e4 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -52,10 +52,10 @@ type filesystem struct { nextInoMinusOne uint64 // accessed using atomic memory operations } -// NewFilesystem implements vfs.FilesystemType.NewFilesystem. -func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { var fs filesystem - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) root := fs.newDentry(fs.newDirectory(creds, 01777)) return &fs.vfsfs, &root.vfsd, nil } @@ -99,17 +99,17 @@ func (fs *filesystem) newDentry(inode *inode) *dentry { } // IncRef implements vfs.DentryImpl.IncRef. -func (d *dentry) IncRef(vfsfs *vfs.Filesystem) { +func (d *dentry) IncRef() { d.inode.incRef() } // TryIncRef implements vfs.DentryImpl.TryIncRef. -func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool { +func (d *dentry) TryIncRef() bool { return d.inode.tryIncRef() } // DecRef implements vfs.DentryImpl.DecRef. -func (d *dentry) DecRef(vfsfs *vfs.Filesystem) { +func (d *dentry) DecRef() { d.inode.decRef() } @@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS stat.Size = uint64(len(impl.target)) stat.Blocks = allocatedBlocksForSize(stat.Size) + case *namedPipe: + stat.Mode |= linux.S_IFIFO default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } @@ -259,28 +261,14 @@ func (i *inode) direntType() uint8 { type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl - - flags uint32 // status flags; immutable } func (fd *fileDescription) filesystem() *filesystem { - return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem) + return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem) } func (fd *fileDescription) inode() *inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode -} - -// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. -func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { - return fd.flags, nil -} - -// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags. -func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - // None of the flags settable by fcntl(F_SETFL) are supported, so this is a - // no-op. - return nil + return fd.vfsfd.Dentry().Impl().(*dentry).inode } // Stat implements vfs.FileDescriptionImpl.Stat. diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go new file mode 100644 index 000000000..d5060850e --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/named_pipe.go @@ -0,0 +1,62 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type namedPipe struct {
+	inode inode
+
+	pipe *pipe.VFSPipe
+}
+
+// Preconditions:
+// * fs.mu must be locked.
+// * rp.Mount().CheckBeginWrite() has been called successfully.
+func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
+	file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
+	file.inode.init(file, fs, creds, mode)
+	file.inode.nlink = 1 // Only the parent has a link.
+	return &file.inode
+}
+
+// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
+// entirely via struct embedding.
+type namedPipeFD struct {
+	fileDescription
+
+	*pipe.VFSPipeFD
+}
+
+func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+	var err error
+	var fd namedPipeFD
+	fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags)
+	if err != nil {
+		return nil, err
+	}
+	mnt := rp.Mount()
+	mnt.IncRef()
+	vfsd.IncRef()
+	fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
+	return &fd.vfsfd, nil
+}
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go
new file mode 100644
index 000000000..5bf527c80
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/pipe_test.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+	"bytes"
+	"testing"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+const fileName = "mypipe"
+
+func TestSeparateFDs(t *testing.T) {
+	ctx, creds, vfsObj, root := setup(t)
+	defer root.DecRef()
+
+	// Open the read side. This is done in a separate goroutine because opening
+	// one end of the pipe blocks until the other end is opened.
+	pop := vfs.PathOperation{
+		Root:               root,
+		Start:              root,
+		Pathname:           fileName,
+		FollowFinalSymlink: true,
+	}
+	rfdchan := make(chan *vfs.FileDescription)
+	go func() {
+		openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY}
+		rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+		rfdchan <- rfd
+	}()
+
+	// Open the write side.
+	openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY}
+	wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+	if err != nil {
+		t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+	}
+	defer wfd.DecRef()
+
+	rfd := <-rfdchan
+	if rfd == nil {
+		t.Fatalf("failed to open pipe for reading %q", fileName)
+	}
+	defer rfd.DecRef()
+
+	const msg = "vamos azul"
+	checkEmpty(ctx, t, rfd)
+	checkWrite(ctx, t, wfd, msg)
+	checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingRead(t *testing.T) {
+	ctx, creds, vfsObj, root := setup(t)
+	defer root.DecRef()
+
+	// Open the read side as nonblocking.
+	pop := vfs.PathOperation{
+		Root:               root,
+		Start:              root,
+		Pathname:           fileName,
+		FollowFinalSymlink: true,
+	}
+	openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
+	rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+	if err != nil {
+		t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
+	}
+	defer rfd.DecRef()
+
+	// Open the write side.
+	openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
+	wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+	if err != nil {
+		t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+	}
+	defer wfd.DecRef()
+
+	const msg = "geh blau"
+	checkEmpty(ctx, t, rfd)
+	checkWrite(ctx, t, wfd, msg)
+	checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingWriteError(t *testing.T) {
+	ctx, creds, vfsObj, root := setup(t)
+	defer root.DecRef()
+
+	// Open the write side as nonblocking, which should return ENXIO.
+	pop := vfs.PathOperation{
+		Root:               root,
+		Start:              root,
+		Pathname:           fileName,
+		FollowFinalSymlink: true,
+	}
+	openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
+	_, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+	if err != syserror.ENXIO {
+		t.Fatalf("expected ENXIO, but got error: %v", err)
+	}
+}
+
+func TestSingleFD(t *testing.T) {
+	ctx, creds, vfsObj, root := setup(t)
+	defer root.DecRef()
+
+	// Open the pipe as readable and writable.
+	pop := vfs.PathOperation{
+		Root:               root,
+		Start:              root,
+		Pathname:           fileName,
+		FollowFinalSymlink: true,
+	}
+	openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
+	fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+	if err != nil {
+		t.Fatalf("failed to open pipe %q: %v", fileName, err)
+	}
+	defer fd.DecRef()
+
+	const msg = "forza blu"
+	checkEmpty(ctx, t, fd)
+	checkWrite(ctx, t, fd, msg)
+	checkRead(ctx, t, fd, msg)
+}
+
+// setup creates a VFS with a pipe in the root directory at path fileName. The
+// returned VirtualDentry must be DecRef()'d by the caller. It calls t.Fatal
+// upon failure.
+func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) {
+	ctx := contexttest.Context(t)
+	creds := auth.CredentialsFromContext(ctx)
+
+	// Create VFS.
+	vfsObj := vfs.New()
+	vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
+	mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+	if err != nil {
+		t.Fatalf("failed to create memfs root mount: %v", err)
+	}
+
+	// Create the pipe.
+	root := mntns.Root()
+	pop := vfs.PathOperation{
+		Root:               root,
+		Start:              root,
+		Pathname:           fileName,
+		FollowFinalSymlink: true,
+	}
+	mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
+	if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
+		t.Fatalf("failed to create pipe %q: %v", fileName, err)
+	}
+
+	// Sanity check: the pipe exists and has the correct mode.
+	stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+		Root:               root,
+		Start:              root,
+		Pathname:           fileName,
+		FollowFinalSymlink: true,
+	}, &vfs.StatOptions{})
+	if err != nil {
+		t.Fatalf("stat(%q) failed: %v", fileName, err)
+	}
+	if stat.Mode&^linux.S_IFMT != 0644 {
+		t.Errorf("got wrong permissions (%0o)", stat.Mode)
+	}
+	if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe {
+		t.Errorf("got wrong file type (%0o)", stat.Mode)
+	}
+
+	return ctx, creds, vfsObj, root
+}
+
+// checkEmpty calls t.Fatal if the pipe in fd is not empty.
+func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+	readData := make([]byte, 1)
+	dst := usermem.BytesIOSequence(readData)
+	bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
+	if err != syserror.ErrWouldBlock {
+		t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
+	}
+	if bytesRead != 0 {
+		t.Fatalf("expected to read 0 bytes, but got %d", bytesRead)
+	}
+}
+
+// checkWrite calls t.Fatal if it fails to write all of msg to fd.
+func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+	writeData := []byte(msg)
+	src := usermem.BytesIOSequence(writeData)
+	bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{})
+	if err != nil {
+		t.Fatalf("error writing to pipe %q: %v", fileName, err)
+	}
+	if bytesWritten != int64(len(writeData)) {
+		t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten)
+	}
+}
+
+// checkRead calls t.Fatal if it fails to read msg from fd.
+func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+	readData := make([]byte, len(msg))
+	dst := usermem.BytesIOSequence(readData)
+	bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
+	if err != nil {
+		t.Fatalf("error reading from pipe %q: %v", fileName, err)
+	}
+	if bytesRead != int64(len(msg)) {
+		t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead)
+	}
+	if !bytes.Equal(readData, []byte(msg)) {
+		t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData))
+	}
+}
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go
index c36c4aff5..0e016bca5 100644
--- a/pkg/sentry/fsimpl/proc/filesystems.go
+++ b/pkg/sentry/fsimpl/proc/filesystems.go
@@ -19,7 +19,7 @@ package proc
 // +stateify savable
 type filesystemsData struct{}
 
-// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for
+// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for
 // filesystemsData. We would need to retrive filesystem names from
 // vfs.VirtualFilesystem. Also needs vfs replacement for
 // fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go
index e81b1e910..8683cf677 100644
--- a/pkg/sentry/fsimpl/proc/mounts.go
+++ b/pkg/sentry/fsimpl/proc/mounts.go
@@ -16,7 +16,7 @@ package proc
 
 import "gvisor.dev/gvisor/pkg/sentry/kernel"
 
-// TODO(b/138862512): Implement mountInfoFile and mountsFile.
+// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile.
// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo. // diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index c46e05c3a..0d87be52b 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -229,6 +229,10 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps) fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode()) + // We unconditionally report a single NUMA node. See + // pkg/sentry/syscalls/linux/sys_mempolicy.go. + fmt.Fprintf(buf, "Mems_allowed:\t1\n") + fmt.Fprintf(buf, "Mems_allowed_list:\t0\n") return nil } diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index d5284f0d9..8d60ad4ad 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip/stack", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80f227dbe..a7dfb78a7 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,16 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // RegisteredEndpoints returns all endpoints which are currently registered. + RegisteredEndpoints() []stack.TransportEndpoint + + // CleanupEndpoints returns endpoints currently in the cleanup state. + CleanupEndpoints() []stack.TransportEndpoint + + // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful + // for restoring a stack after a save. + RestoreCleanupEndpoints([]stack.TransportEndpoint) } // Interface contains information about a network interface. @@ -153,3 +165,23 @@ type Route struct { // GatewayAddr is the route gateway address (RTA_GATEWAY). GatewayAddr []byte } + +// Below SNMP metrics are from Linux/usr/include/linux/snmp.h. + +// StatSNMPIP describes Ip line of /proc/net/snmp. +type StatSNMPIP [19]uint64 + +// StatSNMPICMP describes Icmp line of /proc/net/snmp. +type StatSNMPICMP [27]uint64 + +// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp. +type StatSNMPICMPMSG [512]uint64 + +// StatSNMPTCP describes Tcp line of /proc/net/snmp. +type StatSNMPTCP [15]uint64 + +// StatSNMPUDP describes Udp line of /proc/net/snmp. +type StatSNMPUDP [8]uint64 + +// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. +type StatSNMPUDPLite [8]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..dcfcbd97e 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,8 @@ package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route { } // Resume implements Stack.Resume. -func (s *TestStack) Resume() { +func (s *TestStack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. 
+func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint {
+	return nil
 }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
+	return nil
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e041c51b3..2706927ff 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -35,7 +35,7 @@ go_template_instance(
     out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
     package = "kernel",
     suffix = "TaskGoroutineSchedInfo",
-    template = "//third_party/gvsync:generic_seqatomic",
+    template = "//pkg/syncutil:generic_seqatomic",
     types = {
         "Value": "TaskGoroutineSchedInfo",
     },
@@ -209,12 +209,12 @@ go_library(
         "//pkg/sentry/usermem",
         "//pkg/state",
         "//pkg/state/statefile",
+        "//pkg/syncutil",
         "//pkg/syserr",
         "//pkg/syserror",
         "//pkg/tcpip",
         "//pkg/tcpip/stack",
         "//pkg/waiter",
-        "//third_party/gvsync",
     ],
 )
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 51de4568a..04c244447 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
     out = "atomicptr_credentials_unsafe.go",
     package = "auth",
    suffix = "Credentials",
-    template = "//third_party/gvsync:generic_atomicptr",
+    template = "//pkg/syncutil:generic_atomicptr",
    types = {
        "Value": "Credentials",
    },
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index e3f5b0d83..3c9dceaba 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -15,6 +15,8 @@
 package kernel
 
 import (
+	"time"
+
 	"gvisor.dev/gvisor/pkg/log"
 	"gvisor.dev/gvisor/pkg/sentry/context"
 )
@@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task {
 	return nil
 }
 
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+	return time.Time{}, false
+}
+
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+	return nil
+}
+
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+	return nil
+}
+
 // AsyncContext returns a context.Context that may be used by goroutines that
 // do work on behalf of t and therefore share its contextual values, but are
 // not t's task goroutine (e.g. asynchronous I/O).
@@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
 	return ctx.t.IsLogging(level)
 }
 
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+	return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+	return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+	return ctx.t.Err()
+}
+
 // Value implements context.Context.Value.
 func (ctx taskAsyncContext) Value(key interface{}) interface{} {
 	return ctx.t.Value(key)
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index cc3f43a45..11f613a11 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -81,6 +81,9 @@ type FDTable struct {
 	// mu protects below.
 	mu sync.Mutex `state:"nosave"`
 
+	// next is the starting position for finding an available fd.
+	next int32
+
 	// used contains the number of non-nil entries. It must be accessed
 	// atomically. It may be read atomically without holding mu (but not
 	// written).
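Aside: the new next field above is a search hint with the invariant that no fd below it is free. The NewFDs hunk that follows starts its scan at the hint and advances it when the scan began there, while Remove and RemoveIf lower it whenever a smaller fd is released. The sketch below is a minimal, self-contained model of that strategy; the names (table, alloc, release) are illustrative and not part of gVisor's FDTable API.

    // fdhint_sketch.go models the lowest-free-fd hint used by FDTable.
    package main

    import "fmt"

    type table struct {
    	entries map[int32]bool
    	next    int32 // invariant: no fd below next is free
    }

    // alloc returns the first free fd >= min, using next as a shortcut.
    func (t *table) alloc(min int32) int32 {
    	start := min
    	if start < t.next {
    		start = t.next // nothing below next can be free
    	}
    	fd := start
    	for t.entries[fd] {
    		fd++
    	}
    	t.entries[fd] = true
    	// Advance the hint only if the scan began at it; otherwise a
    	// lower fd may still be free.
    	if start == t.next {
    		t.next = fd + 1
    	}
    	return fd
    }

    // release frees fd and lowers the hint, mirroring Remove/RemoveIf.
    func (t *table) release(fd int32) {
    	delete(t.entries, fd)
    	if fd < t.next {
    		t.next = fd
    	}
    }

    func main() {
    	t := &table{entries: make(map[int32]bool)}
    	fmt.Println(t.alloc(0), t.alloc(0), t.alloc(0)) // 0 1 2
    	t.release(1)
    	fmt.Println(t.alloc(0)) // 1: release lowered the hint
    }

This keeps allocation from rescanning the already-full prefix on every call while still handing out the lowest available descriptor, matching POSIX fd allocation order.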
@@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
 	f.mu.Lock()
 	defer f.mu.Unlock()
 
+	// Start searching at f.next to find the lowest available fd.
+	if fd < f.next {
+		fd = f.next
+	}
+
 	// Install all entries.
 	for i := fd; i < end && len(fds) < len(files); i++ {
 		if d, _, _ := f.get(i); d == nil {
@@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
 		return nil, syscall.EMFILE
 	}
 
+	if fd == f.next {
+		// Update the start position for the next search.
+		f.next = fds[len(fds)-1] + 1
+	}
+
 	return fds, nil
 }
 
@@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File {
 	f.mu.Lock()
 	defer f.mu.Unlock()
 
+	// Lower the search start position if this fd is below it.
+	if fd < f.next {
+		f.next = fd
+	}
+
 	orig, _, _ := f.get(fd)
 	if orig != nil {
 		orig.IncRef() // Reference for caller.
@@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) {
 	f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
 		if cond(file, flags) {
 			f.set(fd, nil, FDFlags{}) // Clear from table.
+			// Lower the search start position if this fd is below it.
+			if fd < f.next {
+				f.next = fd
+			}
 		}
 	})
 }
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2413788e7..2bcb6216a 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) {
 		if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil {
 			t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
 		}
+
+		i := int32(2)
+		fdTable.Remove(i)
+		if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i {
+			t.Fatalf("NewFDs(0, r) after removing fd %v: got %v, %v, wanted [%v], nil", i, fds, err, i)
+		}
+	})
+}
+
+func TestFDTableOverLimit(t *testing.T) {
+	runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+		if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil {
+			t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error")
+		}
+
+		if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil {
+			t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error")
+		}
+
+		if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil {
+			t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err)
+		} else {
+			for _, fd := range fds {
+				fdTable.Remove(fd)
+			}
+		}
+
+		if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 {
+			t.Fatalf("fdTable.NewFDs(maxFD-1, f): got %v, %v, wanted [maxFD-1], nil", fds, err)
+		}
+
+		if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+			t.Fatalf("Adding an FD to a resized map: got %v, want nil", err)
+		} else if len(fds) != 1 || fds[0] != 0 {
+			t.Fatalf("Added an FD to a resized map: got %v, want {0}", fds)
+		}
 	})
 }
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 34286c7a8..75ec31761 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
     out = "atomicptr_bucket_unsafe.go",
     package = "futex",
     suffix = "Bucket",
-    template = "//third_party/gvsync:generic_atomicptr",
+    template = "//pkg/syncutil:generic_atomicptr",
     types = {
         "Value": "bucket",
     },
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 3cda03891..bd3fb4c03 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w
io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(w, k.FeatureSet(), nil); err != nil { + if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) @@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Save(w, k, &stats); err != nil { + if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { return err } log.Infof("Kernel save stats: %s", &stats) @@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(w); err != nil { + if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(r, &features, nil); err != nil { + if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Load(r, k, &stats); err != nil { + if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { return err } log.Infof("Kernel load stats: %s", &stats) @@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(r); err != nil { + if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -804,8 +804,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: mounts, + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: true, + Filename: args.Filename, + File: args.File, + CloseOnExec: false, + Argv: args.Argv, + Envv: args.Envv, + Features: k.featureSet, + } - tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) + tc, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } @@ -828,9 +841,11 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, AbstractSocketNamespace: args.AbstractSocketNamespace, ContainerID: args.ContainerID, } - if _, err := k.tasks.NewTask(config); err != nil { + t, err := k.tasks.NewTask(config) + if err != nil { return nil, 0, err } + t.traceExecEvent(tc) // Simulate exec for tracing. // Success. tgid := k.tasks.Root.IDOfThreadGroup(tg) @@ -1105,6 +1120,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { return lastErr } +// RebuildTraceContexts rebuilds the trace context for all tasks. +// +// Unfortunately, if these are built while tracing is not enabled, then we will +// not have meaningful trace data. 
Rebuilding here ensures that we can do so +// after tracing has been enabled. +func (k *Kernel) RebuildTraceContexts() { + k.extMu.Lock() + defer k.extMu.Unlock() + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + + for t, tid := range k.tasks.Root.tids { + t.rebuildTraceContext(tid) + } +} + // FeatureSet returns the FeatureSet. func (k *Kernel) FeatureSet() *cpuid.FeatureSet { return k.featureSet @@ -1309,6 +1340,7 @@ func (k *Kernel) ListSockets() []*SocketEntry { return socks } +// supervisorContext is a privileged context. type supervisorContext struct { context.NoopSleeper log.Logger diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index cde647139..9d34f6d4d 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -24,8 +24,10 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_util.go", "reader.go", "reader_writer.go", + "vfs.go", "writer.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", @@ -40,6 +42,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index a2dc72204..4a19ab7ce 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi switch { case flags.Read && !flags.Write: // O_RDONLY. r := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) + newHandleLocked(&i.rWakeup) if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { - if !i.waitFor(&i.wWakeup, ctx) { + if !waitFor(&i.mu, &i.wWakeup, ctx) { r.DecRef() return nil, syserror.ErrInterrupted } @@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Write && !flags.Read: // O_WRONLY. w := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.wWakeup) if i.p.isNamed && !i.p.HasReaders() { // On a nonblocking, write-only open, the open fails with ENXIO if the @@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi return nil, syserror.ENXIO } - if !i.waitFor(&i.rWakeup, ctx) { + if !waitFor(&i.mu, &i.rWakeup, ctx) { w.DecRef() return nil, syserror.ErrInterrupted } @@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Read && flags.Write: // O_RDWR. // Pipes opened for read-write always succeeds without blocking. rw := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.rWakeup) + newHandleLocked(&i.wWakeup) return rw, nil default: @@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi } } -// waitFor blocks until the underlying pipe has at least one reader/writer is -// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this -// function will block for either readers or writers, depending on where -// 'wakeupChan' points. -// -// f.mu must be held by the caller. waitFor returns with f.mu held, but it will -// drop f.mu before blocking for any reader/writers. 
-func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { - // Ideally this function would simply use a condition variable. However, the - // wait needs to be interruptible via 'sleeper', so we must sychronize via a - // channel. The synchronization below relies on the fact that closing a - // channel unblocks all receives on the channel. - - // Does an appropriate wakeup channel already exist? If not, create a new - // one. This is all done under f.mu to avoid races. - if *wakeupChan == nil { - *wakeupChan = make(chan struct{}) - } - - // Grab a local reference to the wakeup channel since it may disappear as - // soon as we drop f.mu. - wakeup := *wakeupChan - - // Drop the lock and prepare to sleep. - i.mu.Unlock() - cancel := sleeper.SleepStart() - - // Wait for either a new reader/write to be signalled via 'wakeup', or - // for the sleep to be cancelled. - select { - case <-wakeup: - sleeper.SleepFinish(true) - case <-cancel: - sleeper.SleepFinish(false) - } - - // Take the lock and check if we were woken. If we were woken and - // interrupted, the former takes priority. - i.mu.Lock() - select { - case <-wakeup: - return true - default: - return false - } -} - -// newHandleLocked signals a new pipe reader or writer depending on where -// 'wakeupChan' points. This unblocks any corresponding reader or writer -// waiting for the other end of the channel to be opened, see Fifo.waitFor. -// -// i.mu must be held. -func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) { - if *wakeupChan != nil { - close(*wakeupChan) - *wakeupChan = nil - } -} - func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error { return syserror.EPIPE } diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 8e4e8e82e..1a1b38f83 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -111,11 +111,27 @@ func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { if atomicIOBytes > sizeBytes { atomicIOBytes = sizeBytes } - return &Pipe{ - isNamed: isNamed, - max: sizeBytes, - atomicIOBytes: atomicIOBytes, + var p Pipe + initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + return &p +} + +func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { + if sizeBytes < MinimumPipeSize { + sizeBytes = MinimumPipeSize + } + if sizeBytes > MaximumPipeSize { + sizeBytes = MaximumPipeSize + } + if atomicIOBytes <= 0 { + atomicIOBytes = 1 + } + if atomicIOBytes > sizeBytes { + atomicIOBytes = sizeBytes } + pipe.isNamed = isNamed + pipe.max = sizeBytes + pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go new file mode 100644 index 000000000..ef9641e6a --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -0,0 +1,213 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
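Aside: the waitFor helper moved into pipe_util.go below implements an interruptible rendezvous. A wakeup channel is created lazily under the pipe's mutex, the peer closes it to broadcast its arrival, and a cancel channel from the sleeper keeps the wait interruptible. The sketch below is a self-contained model of that pattern under simplified assumptions; fifo, waitForReader, and announceReader are illustrative names, not the gVisor API, and the real code also checks for an already-present peer before sleeping.

    // rendezvous_sketch.go models the close-to-broadcast wait in waitFor.
    package main

    import (
    	"fmt"
    	"sync"
    	"time"
    )

    type fifo struct {
    	mu      sync.Mutex
    	rWakeup chan struct{} // lazily created; closed when a reader arrives
    }

    // waitForReader must be called with f.mu held and returns with f.mu
    // held, but drops it while sleeping. cancel models the interruptible
    // sleeper; if it fires first, the wait reports failure.
    func (f *fifo) waitForReader(cancel <-chan struct{}) bool {
    	if f.rWakeup == nil {
    		f.rWakeup = make(chan struct{})
    	}
    	wakeup := f.rWakeup // local copy; the field may be reset after unlock
    	f.mu.Unlock()
    	select {
    	case <-wakeup:
    	case <-cancel:
    	}
    	f.mu.Lock()
    	// Re-check under the lock: if both fired, being woken wins.
    	select {
    	case <-wakeup:
    		return true
    	default:
    		return false
    	}
    }

    // announceReader closes and clears the channel; closing a channel
    // unblocks every receiver, so all waiters wake at once.
    func (f *fifo) announceReader() {
    	f.mu.Lock()
    	defer f.mu.Unlock()
    	if f.rWakeup != nil {
    		close(f.rWakeup)
    		f.rWakeup = nil
    	}
    }

    func main() {
    	f := &fifo{}
    	done := make(chan struct{})
    	go func() {
    		f.mu.Lock()
    		woken := f.waitForReader(make(chan struct{}))
    		f.mu.Unlock()
    		fmt.Println("woken:", woken)
    		close(done)
    	}()
    	time.Sleep(50 * time.Millisecond) // demo only: let the waiter block first
    	f.announceReader()
    	<-done
    }

Closing the channel rather than sending on it is what makes a single announcement a broadcast, and re-checking the channel under the lock resolves the race where both wakeup and cancellation fire.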
+
+package pipe
+
+import (
+	"io"
+	"math"
+	"sync"
+	"syscall"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/amutex"
+	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains Pipe file functionality that is tied to neither VFS nor
+// the old fs architecture.
+
+// Release cleans up the pipe's state.
+func (p *Pipe) Release() {
+	p.rClose()
+	p.wClose()
+
+	// Wake up readers and writers.
+	p.Notify(waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads from the Pipe into dst.
+func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+	n, err := p.read(ctx, readOps{
+		left: func() int64 {
+			return dst.NumBytes()
+		},
+		limit: func(l int64) {
+			dst = dst.TakeFirst64(l)
+		},
+		read: func(buf *buffer) (int64, error) {
+			n, err := dst.CopyOutFrom(ctx, buf)
+			dst = dst.DropFirst64(n)
+			return n, err
+		},
+	})
+	if n > 0 {
+		p.Notify(waiter.EventOut)
+	}
+	return n, err
+}
+
+// WriteTo writes to w from the Pipe.
+func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) {
+	ops := readOps{
+		left: func() int64 {
+			return count
+		},
+		limit: func(l int64) {
+			count = l
+		},
+		read: func(buf *buffer) (int64, error) {
+			n, err := buf.ReadToWriter(w, count, dup)
+			count -= n
+			return n, err
+		},
+	}
+	if dup {
+		// There is no notification for dup operations.
+		return p.dup(ctx, ops)
+	}
+	n, err := p.read(ctx, ops)
+	if n > 0 {
+		p.Notify(waiter.EventOut)
+	}
+	return n, err
+}
+
+// Write writes to the Pipe from src.
+func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+	n, err := p.write(ctx, writeOps{
+		left: func() int64 {
+			return src.NumBytes()
+		},
+		limit: func(l int64) {
+			src = src.TakeFirst64(l)
+		},
+		write: func(buf *buffer) (int64, error) {
+			n, err := src.CopyInTo(ctx, buf)
+			src = src.DropFirst64(n)
+			return n, err
+		},
+	})
+	if n > 0 {
+		p.Notify(waiter.EventIn)
+	}
+	return n, err
+}
+
+// ReadFrom reads from r to the Pipe.
+func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) {
+	n, err := p.write(ctx, writeOps{
+		left: func() int64 {
+			return count
+		},
+		limit: func(l int64) {
+			count = l
+		},
+		write: func(buf *buffer) (int64, error) {
+			n, err := buf.WriteFromReader(r, count)
+			count -= n
+			return n, err
+		},
+	})
+	if n > 0 {
+		p.Notify(waiter.EventIn)
+	}
+	return n, err
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask {
+	return p.rwReadiness() & mask
+}
+
+// Ioctl implements ioctls on the Pipe.
+func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+	// Switch on ioctl request.
+	switch int(args[1].Int()) {
+	case linux.FIONREAD:
+		v := p.queued()
+		if v > math.MaxInt32 {
+			v = math.MaxInt32 // Silently truncate.
+		}
+		// Copy result to user-space.
+		_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+			AddressSpaceActive: true,
+		})
+		return 0, err
+	default:
+		return 0, syscall.ENOTTY
+	}
+}
+
+// waitFor blocks until the underlying pipe has at least one reader/writer
+// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
+// function will block for either readers or writers, depending on where
+// 'wakeupChan' points.
+//
+// mu must be held by the caller. waitFor returns with mu held, but it will
+// drop mu before blocking for any reader/writers.
+func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
+	// Ideally this function would simply use a condition variable. However, the
+	// wait needs to be interruptible via 'sleeper', so we must synchronize via a
+	// channel. The synchronization below relies on the fact that closing a
+	// channel unblocks all receives on the channel.
+
+	// Does an appropriate wakeup channel already exist? If not, create a new
+	// one. This is all done under mu to avoid races.
+	if *wakeupChan == nil {
+		*wakeupChan = make(chan struct{})
+	}
+
+	// Grab a local reference to the wakeup channel since it may disappear as
+	// soon as we drop mu.
+	wakeup := *wakeupChan
+
+	// Drop the lock and prepare to sleep.
+	mu.Unlock()
+	cancel := sleeper.SleepStart()
+
+	// Wait for either a new reader/writer to be signalled via 'wakeup', or
+	// for the sleep to be cancelled.
+	select {
+	case <-wakeup:
+		sleeper.SleepFinish(true)
+	case <-cancel:
+		sleeper.SleepFinish(false)
+	}
+
+	// Take the lock and check if we were woken. If we were woken and
+	// interrupted, the former takes priority.
+	mu.Lock()
+	select {
+	case <-wakeup:
+		return true
+	default:
+		return false
+	}
+}
+
+// newHandleLocked signals a new pipe reader or writer depending on where
+// 'wakeupChan' points. This unblocks any corresponding reader or writer
+// waiting for the other end of the channel to be opened, see waitFor.
+//
+// Precondition: the mutex protecting wakeupChan must be held.
+func newHandleLocked(wakeupChan *chan struct{}) {
+	if *wakeupChan != nil {
+		close(*wakeupChan)
+		*wakeupChan = nil
+	}
+}
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index 7c307f013..b4d29fc77 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -16,16 +16,12 @@ package pipe
 
 import (
 	"io"
-	"math"
-	"syscall"
 
-	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
 	"gvisor.dev/gvisor/pkg/sentry/context"
 	"gvisor.dev/gvisor/pkg/sentry/fs"
 	"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
 	"gvisor.dev/gvisor/pkg/sentry/usermem"
-	"gvisor.dev/gvisor/pkg/waiter"
 )
 
 // ReaderWriter satisfies the FileOperations interface and services both
@@ -45,124 +41,27 @@ type ReaderWriter struct {
 	*Pipe
 }
 
-// Release implements fs.FileOperations.Release.
-func (rw *ReaderWriter) Release() {
-	rw.Pipe.rClose()
-	rw.Pipe.wClose()
-
-	// Wake up readers and writers.
-	rw.Pipe.Notify(waiter.EventIn | waiter.EventOut)
-}
-
 // Read implements fs.FileOperations.Read.
 func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
-	n, err := rw.Pipe.read(ctx, readOps{
-		left: func() int64 {
-			return dst.NumBytes()
-		},
-		limit: func(l int64) {
-			dst = dst.TakeFirst64(l)
-		},
-		read: func(buf *buffer) (int64, error) {
-			n, err := dst.CopyOutFrom(ctx, buf)
-			dst = dst.DropFirst64(n)
-			return n, err
-		},
-	})
-	if n > 0 {
-		rw.Pipe.Notify(waiter.EventOut)
-	}
-	return n, err
+	return rw.Pipe.Read(ctx, dst)
 }
 
 // WriteTo implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { - ops := readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(buf *buffer) (int64, error) { - n, err := buf.ReadToWriter(w, count, dup) - count -= n - return n, err - }, - } - if dup { - // There is no notification for dup operations. - return rw.Pipe.dup(ctx, ops) - } - n, err := rw.Pipe.read(ctx, ops) - if n > 0 { - rw.Pipe.Notify(waiter.EventOut) - } - return n, err + return rw.Pipe.WriteTo(ctx, w, count, dup) } // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, writeOps{ - left: func() int64 { - return src.NumBytes() - }, - limit: func(l int64) { - src = src.TakeFirst64(l) - }, - write: func(buf *buffer) (int64, error) { - n, err := src.CopyInTo(ctx, buf) - src = src.DropFirst64(n) - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err + return rw.Pipe.Write(ctx, src) } // ReadFrom implements fs.FileOperations.WriteTo. func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - n, err := rw.Pipe.write(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(buf *buffer) (int64, error) { - n, err := buf.WriteFromReader(r, count) - count -= n - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err -} - -// Readiness returns the ready events in the underlying pipe. -func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask { - return rw.Pipe.rwReadiness() & mask + return rw.Pipe.ReadFrom(ctx, r, count) } // Ioctl implements fs.FileOperations.Ioctl. func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // Switch on ioctl request. - switch int(args[1].Int()) { - case linux.FIONREAD: - v := rw.queued() - if v > math.MaxInt32 { - v = math.MaxInt32 // Silently truncate. - } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err - default: - return 0, syscall.ENOTTY - } + return rw.Pipe.Ioctl(ctx, io, args) } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go new file mode 100644 index 000000000..6416e0dd8 --- /dev/null +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -0,0 +1,220 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package pipe
+
+import (
+	"sync"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/context"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains types enabling the pipe package to be used with the vfs
+// package.
+
+// VFSPipe represents the actual pipe, analogous to an inode. VFSPipes should
+// not be copied.
+type VFSPipe struct {
+	// mu protects the fields below.
+	mu sync.Mutex `state:"nosave"`
+
+	// pipe is the underlying pipe.
+	pipe Pipe
+
+	// Channels for synchronizing the creation of new readers and writers
+	// of this fifo. See waitFor and newHandleLocked.
+	//
+	// These are not saved/restored because all waiters are unblocked on
+	// save, and either automatically restart (via ERESTARTSYS) or return
+	// EINTR on resume. On restarts via ERESTARTSYS, the appropriate
+	// channel will be recreated.
+	rWakeup chan struct{} `state:"nosave"`
+	wWakeup chan struct{} `state:"nosave"`
+}
+
+// NewVFSPipe returns an initialized VFSPipe.
+func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+	var vp VFSPipe
+	initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+	return &vp
+}
+
+// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
+// during open:
+//
+// "Normally, opening the FIFO blocks until the other end is opened also. A
+// process can open a FIFO in nonblocking mode. In this case, opening for
+// read-only will succeed even if no-one has opened on the write side yet,
+// opening for write-only will fail with ENXIO (no such device or address)
+// unless the other end has already been opened. Under Linux, opening a FIFO
+// for read and write will succeed both in blocking and nonblocking mode. POSIX
+// leaves this behavior undefined. This can be used to open a FIFO for writing
+// while there are no readers available." - fifo(7)
+func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+	vp.mu.Lock()
+	defer vp.mu.Unlock()
+
+	readable := vfs.MayReadFileWithOpenFlags(flags)
+	writable := vfs.MayWriteFileWithOpenFlags(flags)
+	if !readable && !writable {
+		return nil, syserror.EINVAL
+	}
+
+	vfd, err := vp.open(rp, vfsd, vfsfd, flags)
+	if err != nil {
+		return nil, err
+	}
+
+	switch {
+	case readable && writable:
+		// Pipes opened for read-write always succeed without blocking.
+		newHandleLocked(&vp.rWakeup)
+		newHandleLocked(&vp.wWakeup)
+
+	case readable:
+		newHandleLocked(&vp.rWakeup)
+		// If this pipe is being opened as blocking and there's no
+		// writer, we have to wait for a writer to open the other end.
+		if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+			return nil, syserror.EINTR
+		}
+
+	case writable:
+		newHandleLocked(&vp.wWakeup)
+
+		if !vp.pipe.HasReaders() {
+			// Nonblocking, write-only opens fail with ENXIO when
+			// the read side isn't open yet.
+			if flags&linux.O_NONBLOCK != 0 {
+				return nil, syserror.ENXIO
+			}
+			// Wait for a reader to open the other end.
+			if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+				return nil, syserror.EINTR
+			}
+		}
+
+	default:
+		panic("invalid pipe flags: must be readable, writable, or both")
+	}
+
+	return vfd, nil
+}
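
The switch above encodes exactly the fifo(7) rules quoted in the comment, and the same behavior is observable from ordinary userspace. A small sketch, assuming a Linux host and an arbitrary scratch path:

package main

import (
	"fmt"
	"os"
	"syscall"
)

func main() {
	path := "/tmp/demo.fifo" // assumed scratch location
	if err := syscall.Mkfifo(path, 0600); err != nil {
		panic(err)
	}
	defer os.Remove(path)

	// Write-only, nonblocking, with no reader yet: fails with ENXIO,
	// matching the writable case above.
	_, err := os.OpenFile(path, os.O_WRONLY|syscall.O_NONBLOCK, 0)
	fmt.Println("O_WRONLY|O_NONBLOCK:", err)

	// Read-write succeeds immediately, matching the readable && writable
	// case above (Linux behavior; POSIX leaves it undefined).
	f, err := os.OpenFile(path, os.O_RDWR, 0)
	fmt.Println("O_RDWR:", err)
	if f != nil {
		f.Close()
	}
}

A blocking read-only open would simply hang until a writer appears, which is why the readable case above only waits when O_NONBLOCK is clear and there are no writers yet.
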
+
+// Preconditions: vp.mu must be held.
+func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+	var fd VFSPipeFD
+	fd.flags = flags
+	fd.readable = vfs.MayReadFileWithOpenFlags(flags)
+	fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
+	fd.vfsfd = vfsfd
+	fd.pipe = &vp.pipe
+	if fd.writable {
+		// The corresponding Mount.EndWrite() is in VFSPipeFD.Release().
+		if err := rp.Mount().CheckBeginWrite(); err != nil {
+			return nil, err
+		}
+	}
+
+	switch {
+	case fd.readable && fd.writable:
+		vp.pipe.rOpen()
+		vp.pipe.wOpen()
+	case fd.readable:
+		vp.pipe.rOpen()
+	case fd.writable:
+		vp.pipe.wOpen()
+	default:
+		panic("invalid pipe flags: must be readable, writable, or both")
+	}
+
+	return &fd, nil
+}
+
+// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
+// expected that filesystems will use this in a struct implementing
+// vfs.FileDescriptionImpl.
+type VFSPipeFD struct {
+	pipe     *Pipe
+	flags    uint32
+	readable bool
+	writable bool
+	vfsfd    *vfs.FileDescription
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *VFSPipeFD) Release() {
+	var event waiter.EventMask
+	if fd.readable {
+		fd.pipe.rClose()
+		event |= waiter.EventIn
+	}
+	if fd.writable {
+		fd.pipe.wClose()
+		event |= waiter.EventOut
+	}
+	if event == 0 {
+		panic("invalid pipe flags: must be readable, writable, or both")
+	}
+
+	if fd.writable {
+		fd.vfsfd.VirtualDentry().Mount().EndWrite()
+	}
+
+	fd.pipe.Notify(event)
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *VFSPipeFD) OnClose(_ context.Context) error {
+	return nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
+	return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+	if !fd.readable {
+		return 0, syserror.EINVAL
+	}
+
+	return fd.pipe.Read(ctx, dst)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
+	return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+	if !fd.writable {
+		return 0, syserror.EINVAL
+	}
+
+	return fd.pipe.Write(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.pipe.Ioctl(ctx, uio, args) +} diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go index 0acdf769d..61e412911 100644 --- a/pkg/sentry/kernel/ptrace_arm64.go +++ b/pkg/sentry/kernel/ptrace_arm64.go @@ -17,7 +17,6 @@ package kernel import ( - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 93fe68a3e..de9617e9d 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred return syserror.ERANGE } - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = val sem.pid = pid s.changeTime = ktime.NowFromContext(ctx) @@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti for i, val := range vals { sem := &s.sems[i] - // TODO(b/29354920): Clear undo entries in all processes + // TODO(gvisor.dev/issue/137): Clear undo entries in all processes. sem.value = int16(val) sem.pid = pid sem.wakeWaiters() @@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch } // All operations succeeded, apply them. - // TODO(b/29354920): handle undo operations. + // TODO(gvisor.dev/issue/137): handle undo operations. for i, v := range tmpVals { s.sems[i].value = v s.sems[i].wakeWaiters() diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 50b69d154..9f7e19b4d 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "signalfd", srcs = ["signalfd.go"], diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 220fa73a2..2fdee0282 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -339,6 +339,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn { return nil } +// LookupName looks up a syscall name. +func (s *SyscallTable) LookupName(sysno uintptr) string { + if sc, ok := s.Table[sysno]; ok { + return sc.Name + } + return fmt.Sprintf("sys_%d", sysno) // Unlikely. +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index c82ef5486..ab0c6c4aa 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -15,6 +15,8 @@ package kernel import ( + gocontext "context" + "runtime/trace" "sync" "sync/atomic" @@ -35,8 +37,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syncutil" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/third_party/gvsync" ) // Task represents a thread of execution in the untrusted app. It @@ -83,7 +85,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. 
- goschedSeq gvsync.SeqCount `state:"nosave"` + goschedSeq syncutil.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called @@ -390,7 +392,14 @@ type Task struct { // logPrefix is a string containing the task's thread ID in the root PID // namespace, and is prepended to log messages emitted by Task.Infof etc. - logPrefix atomic.Value `state:".(string)"` + logPrefix atomic.Value `state:"nosave"` + + // traceContext and traceTask are both used for tracing, and are + // updated along with the logPrefix in updateInfoLocked. + // + // These are exclusive to the task goroutine. + traceContext gocontext.Context `state:"nosave"` + traceTask *trace.Task `state:"nosave"` // creds is the task's credentials. // @@ -528,14 +537,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) { t.ptraceTracer.Store(tracer) } -func (t *Task) saveLogPrefix() string { - return t.logPrefix.Load().(string) -} - -func (t *Task) loadLogPrefix(prefix string) { - t.logPrefix.Store(prefix) -} - func (t *Task) saveSyscallFilters() []bpf.Program { if f := t.syscallFilters.Load(); f != nil { return f.([]bpf.Program) @@ -549,6 +550,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) { // afterLoad is invoked by stateify. func (t *Task) afterLoad() { + t.updateInfoLocked() t.interruptChan = make(chan struct{}, 1) t.gosched.State = TaskGoroutineNonexistent if t.stop != nil { @@ -709,9 +711,9 @@ func (t *Task) FDTable() *FDTable { return t.fdTable } -// GetFile is a convenience wrapper t.FDTable().GetFile. +// GetFile is a convenience wrapper for t.FDTable().Get. // -// Precondition: same as FDTable. +// Precondition: same as FDTable.Get. func (t *Task) GetFile(fd int32) *fs.File { f, _ := t.fdTable.Get(fd) return f diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index dd69939f9..4a4a69ee2 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -16,6 +16,7 @@ package kernel import ( "runtime" + "runtime/trace" "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { runtime.Gosched() } + region := trace.StartRegion(t.traceContext, blockRegion) select { case <-C: + region.End() t.SleepFinish(true) + // Woken by event. return nil case <-interrupt: + region.End() t.SleepFinish(false) // Return the indicated error on interrupt. return syserror.ErrInterrupted case <-timerChan: - // We've timed out. + region.End() t.SleepFinish(true) + // We've timed out. return syserror.ETIMEDOUT } } diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 0916fd658..3eadfedb4 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -299,6 +299,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { // nt that it must receive before its task goroutine starts running. tid := nt.k.tasks.Root.IDOfTask(nt) defer nt.Start(tid) + t.traceCloneEvent(tid) // "If fork/clone and execve are allowed by @prog, any child processes will // be constrained to the same filters and system call ABI as the parent." 
- diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 8639d379f..bb5560acf 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -18,10 +18,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -132,30 +130,21 @@ func (t *Task) Stack() *arch.Stack { return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} } -// LoadTaskImage loads filename into a new TaskContext. +// LoadTaskImage loads a specified file into a new TaskContext. // -// It takes several arguments: -// * mounts: MountNamespace to lookup filename in -// * root: Root to lookup filename under -// * wd: Working directory to lookup filename under -// * maxTraversals: maximum number of symlinks to follow -// * filename: path to binary to load -// * file: an open fs.File object of the binary to load. If set, -// file will be loaded and not filename. -// * argv: Binary argv -// * envv: Binary envv -// * fs: Binary FeatureSet -func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving filename. - if file != nil { - filename = file.MappedName(ctx) +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.MappedName(ctx) } // Prepare a new user address space to load into. m := mm.NewMemoryManager(k, k) defer m.DecUsers(ctx) + args.MemoryManager = m - os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso) + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) if err != nil { return nil, err } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 17a089b90..90a6190f1 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -129,6 +129,7 @@ type runSyscallAfterExecStop struct { } func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { + t.traceExecEvent(r.tc) t.tg.pidns.owner.mu.Lock() t.tg.execing = nil if t.killed() { @@ -253,7 +254,7 @@ func (t *Task) promoteLocked() { t.tg.leader = t t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t]) - t.updateLogPrefixLocked() + t.updateInfoLocked() // Reap the original leader. If it has a tracer, detach it instead of // waiting for it to acknowledge the original leader's death. 
oldLeader.exitParentNotified = true diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 535f03e50..435761e5a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState { type runExitMain struct{} func (*runExitMain) execute(t *Task) taskRunState { + t.traceExitEvent() lastExiter := t.exitThreadGroup() // If the task has a cleartid, and the thread group wasn't killed by a diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index a29e9b9eb..0fb3661de 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -16,6 +16,7 @@ package kernel import ( "fmt" + "runtime/trace" "sort" "gvisor.dev/gvisor/pkg/log" @@ -127,11 +128,88 @@ func (t *Task) debugDumpStack() { } } -// updateLogPrefix updates the task's cached log prefix to reflect its -// current thread ID. +// trace definitions. +// +// Note that all region names are prefixed by ':' in order to ensure that they +// are lexically ordered before all system calls, which use the naked system +// call name (e.g. "read") for maximum clarity. +const ( + traceCategory = "task" + runRegion = ":run" + blockRegion = ":block" + cpuidRegion = ":cpuid" + faultRegion = ":fault" +) + +// updateInfoLocked updates the task's cached log prefix and tracing +// information to reflect its current thread ID. // // Preconditions: The task's owning TaskSet.mu must be locked. -func (t *Task) updateLogPrefixLocked() { +func (t *Task) updateInfoLocked() { // Use the task's TID in the root PID namespace for logging. - t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t])) + tid := t.tg.pidns.owner.Root.tids[t] + t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid)) + t.rebuildTraceContext(tid) +} + +// rebuildTraceContext rebuilds the trace context. +// +// Precondition: the passed tid must be the tid in the root namespace. +func (t *Task) rebuildTraceContext(tid ThreadID) { + // Re-initialize the trace context. + if t.traceTask != nil { + t.traceTask.End() + } + + // Note that we define the "task type" to be the dynamic TID. This does + // not align perfectly with the documentation for "tasks" in the + // tracing package. Tasks may be assumed to be bounded by analysis + // tools. However, if we just use a generic "task" type here, then the + // "user-defined tasks" page on the tracing dashboard becomes nearly + // unusable, as it loads all traces from all tasks. + // + // We can assume that the number of tasks in the system is not + // arbitrarily large (in general it won't be, especially for cases + // where we're collecting a brief profile), so using the TID is a + // reasonable compromise in this case. + t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid)) +} + +// traceCloneEvent is called when a new task is spawned. +// +// ntid must be the new task's ThreadID in the root namespace. +func (t *Task) traceCloneEvent(ntid ThreadID) { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid) +} + +// traceExitEvent is called when a task exits. +func (t *Task) traceExitEvent() { + if !trace.IsEnabled() { + return + } + trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status()) +} + +// traceExecEvent is called when a task calls exec. 
+func (t *Task) traceExecEvent(tc *TaskContext) {
+	if !trace.IsEnabled() {
+		return
+	}
+	d := tc.MemoryManager.Executable()
+	if d == nil {
+		trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
+		return
+	}
+	defer d.DecRef()
+	root := t.fsContext.RootDirectory()
+	if root == nil {
+		trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>")
+		return
+	}
+	defer root.DecRef()
+	n, _ := d.FullName(root)
+	trace.Logf(t.traceContext, traceCategory, "exec: %s", n)
 }
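
The tracing hooks in task_log.go above are thin wrappers over Go's standard runtime/trace package. A minimal standalone sketch of the same primitives follows, with names chosen to mirror the constants above; the workload is a placeholder, and the resulting file can be inspected with `go tool trace trace.out`.

package main

import (
	"context"
	"os"
	"runtime/trace"
)

func main() {
	f, err := os.Create("trace.out")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	if err := trace.Start(f); err != nil {
		panic(err)
	}
	defer trace.Stop()

	// One trace.Task per thread ID, as in rebuildTraceContext above.
	ctx, task := trace.NewTask(context.Background(), "tid:1")
	defer task.End()

	// A region analogous to ":run"; note that regions do not nest, which
	// is why the vsyscall path (see task_run.go below) avoids opening one.
	region := trace.StartRegion(ctx, ":run")
	work()
	region.End()

	// A categorized message, as in traceCloneEvent/traceExitEvent above.
	trace.Logf(ctx, "task", "exit status: %#x", 0)
}

func work() {
	s := 0
	for i := 0; i < 1000000; i++ {
		s += i
	}
	_ = s
}
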
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index c92266c59..d97f8c189 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -17,6 +17,7 @@ package kernel
 import (
 	"bytes"
 	"runtime"
+	"runtime/trace"
 	"sync/atomic"
 
 	"gvisor.dev/gvisor/pkg/abi/linux"
@@ -205,9 +206,11 @@ func (*runApp) execute(t *Task) taskRunState {
 		t.tg.pidns.owner.mu.RUnlock()
 	}
 
+	region := trace.StartRegion(t.traceContext, runRegion)
 	t.accountTaskGoroutineEnter(TaskGoroutineRunningApp)
 	info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU)
 	t.accountTaskGoroutineLeave(TaskGoroutineRunningApp)
+	region.End()
 
 	if clearSinglestep {
 		t.Arch().ClearSingleStep()
@@ -225,6 +228,7 @@ func (*runApp) execute(t *Task) taskRunState {
 
 	case platform.ErrContextSignalCPUID:
 		// Is this a CPUID instruction?
+		region := trace.StartRegion(t.traceContext, cpuidRegion)
 		expected := arch.CPUIDInstruction[:]
 		found := make([]byte, len(expected))
 		_, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
@@ -232,10 +236,12 @@ func (*runApp) execute(t *Task) taskRunState {
 			// Skip the cpuid instruction.
 			t.Arch().CPUIDEmulate(t)
 			t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+			region.End()
 
 			// Resume execution.
 			return (*runApp)(nil)
 		}
+		region.End() // Not an actual CPUID, but required copy-in.
 
 		// The instruction at the given RIP was not a CPUID, and we
 		// fall through to the default signal delivery behavior below.
@@ -251,8 +257,10 @@ func (*runApp) execute(t *Task) taskRunState {
 		// an application-generated signal and we should continue execution
 		// normally.
 		if at.Any() {
+			region := trace.StartRegion(t.traceContext, faultRegion)
 			addr := usermem.Addr(info.Addr())
 			err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+			region.End()
 			if err == nil {
 				// The fault was handled appropriately.
 				// We can resume running the application.
@@ -260,6 +268,12 @@ func (*runApp) execute(t *Task) taskRunState {
 			}
 
 			// Is this a vsyscall that we need emulate?
+			//
+			// Note that we don't track vsyscalls as part of a
+			// specific trace region. This is because regions don't
+			// stack, and the actual system call will count as a
+			// region. We should be able to easily identify
+			// vsyscalls by having a <fault><syscall> pair.
 			if at.Execute {
 				if sysno, ok := t.tc.st.LookupEmulate(addr); ok {
 					return t.doVsyscall(addr, sysno)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index ae6fc4025..3522a4ae5 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -154,10 +154,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
 	// Below this point, newTask is expected not to fail (there is no rollback
 	// of assignTIDsLocked or any of the following).
 
-	// Logging on t's behalf will panic if t.logPrefix hasn't been initialized.
-	// This is the earliest point at which we can do so (since t now has thread
-	// IDs).
-	t.updateLogPrefixLocked()
+	// Logging on t's behalf will panic if t.logPrefix hasn't been
+	// initialized. This is the earliest point at which we can do so
+	// (since t now has thread IDs).
+	t.updateInfoLocked()
 
 	if cfg.InheritParent != nil {
 		t.parent = cfg.InheritParent.parent
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index b543d536a..3180f5560 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -17,6 +17,7 @@ package kernel
 import (
 	"fmt"
 	"os"
+	"runtime/trace"
 	"syscall"
 
 	"gvisor.dev/gvisor/pkg/abi/linux"
@@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
 		ctrl = ctrlStopAndReinvokeSyscall
 	} else {
 		fn := s.Lookup(sysno)
+		var region *trace.Region // Only non-nil if tracing == true.
+		if trace.IsEnabled() {
+			region = trace.StartRegion(t.traceContext, s.LookupName(sysno))
+		}
 		if fn != nil {
 			// Call our syscall implementation.
 			rval, ctrl, err = fn(t, args)
@@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
 			// Use the missing function if not found.
 			rval, err = t.SyscallTable().Missing(t, sysno, args)
 		}
+		if region != nil {
+			region.End()
+		}
 	}
 
 	if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
index 34f84487a..048de26dc 100644
--- a/pkg/sentry/kernel/tty.go
+++ b/pkg/sentry/kernel/tty.go
@@ -21,8 +21,19 @@ import "sync"
 //
 // +stateify savable
 type TTY struct {
+	// Index is the terminal index. It is immutable.
+	Index uint32
+
 	mu sync.Mutex `state:"nosave"`
 
 	// tg is protected by mu.
 	tg *ThreadGroup
 }
+
+// TTY returns the thread group's controlling terminal. If nil, there is no
+// controlling terminal.
+func (tg *ThreadGroup) TTY() *TTY {
+	tg.signalHandlers.mu.Lock()
+	defer tg.signalHandlers.mu.Unlock()
+	return tg.tty
+}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 2d9251e92..6299a3e2f 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -408,6 +408,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
 			start = vaddr
 		}
 		if vaddr < end {
+			// NOTE(b/37474556): Linux allows out-of-order
+			// segments, in violation of the spec.
 			ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
 			return loadedELF{}, syserror.ENOEXEC
 		}
@@ -624,15 +626,15 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, in
 	return loadParsedELF(ctx, m, f, info, 0)
 }
 
-// loadELF loads f into the Task address space.
+// loadELF loads args.File into the Task address space.
 //
 // If loadELF returns ErrSwitchFile it should be called again with the returned
 // path and argv.
// // Preconditions: -// * f is an ELF file -func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) { - bin, ac, err := loadInitialELF(ctx, m, fs, f) +// * args.File is an ELF file +func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { + bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { ctx.Infof("Error loading binary: %v", err) return loadedELF{}, nil, err @@ -640,7 +642,14 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace var interp loadedELF if bin.interpreter != "" { - d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter) + // Even if we do not allow the final link of the script to be + // resolved, the interpreter should still be resolved if it is + // a symlink. + args.ResolveFinal = true + // Refresh the traversal limit. + *args.RemainingTraversals = linux.MaxSymlinkTraversals + args.Filename = bin.interpreter + d, i, err := openPath(ctx, args) if err != nil { ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err @@ -649,7 +658,7 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace // We don't need the Dirent. d.DecRef() - interp, err = loadInterpreterELF(ctx, m, i, bin) + interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin) if err != nil { ctx.Infof("Error loading interpreter: %v", err) return loadedELF{}, nil, err diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 089d1635b..b03eeb005 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package loader loads a binary into a MemoryManager. +// Package loader loads an executable file into a MemoryManager. package loader import ( @@ -20,6 +20,7 @@ import ( "fmt" "io" "path" + "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" @@ -35,6 +36,54 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LoadArgs holds specifications for an executable file to be loaded. +type LoadArgs struct { + // MemoryManager is the memory manager to load the executable into. + MemoryManager *mm.MemoryManager + + // Mounts is the mount namespace in which to look up Filename. + Mounts *fs.MountNamespace + + // Root is the root directory under which to look up Filename. + Root *fs.Dirent + + // WorkingDirectory is the working directory under which to look up + // Filename. + WorkingDirectory *fs.Dirent + + // RemainingTraversals is the maximum number of symlinks to follow to + // resolve Filename. This counter is passed by reference to keep it + // updated throughout the call stack. + RemainingTraversals *uint + + // ResolveFinal indicates whether the final link of Filename should be + // resolved, if it is a symlink. + ResolveFinal bool + + // Filename is the path for the executable. + Filename string + + // File is an open fs.File object of the executable. If File is not + // nil, then File will be loaded and Filename will be ignored. + File *fs.File + + // CloseOnExec indicates that the executable (or one of its parent + // directories) was opened with O_CLOEXEC. 
If the executable is an + // interpreter script, then cause an ENOENT error to occur, since the + // script would otherwise be inaccessible to the interpreter. + CloseOnExec bool + + // Argv is the vector of arguments to pass to the executable. + Argv []string + + // Envv is the vector of environment variables to pass to the + // executable. + Envv []string + + // Features specifies the CPU feature set for the executable. + Features *cpuid.FeatureSet +} + // readFull behaves like io.ReadFull for an *fs.File. func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { var total int64 @@ -51,80 +100,82 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in return total, nil } -// openPath opens name for loading. +// openPath opens args.Filename and checks that it is valid for loading. // -// openPath returns the fs.Dirent and an *fs.File for name, which is not +// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not // installed in the Task FDTable. The caller takes ownership of both. // -// name must be a readable, executable, regular file. -func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) { - if name == "" { +// args.Filename must be a readable, executable, regular file. +func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) { + if args.Filename == "" { ctx.Infof("cannot open empty name") return nil, nil, syserror.ENOENT } - d, err := mm.FindInode(ctx, root, wd, name, maxTraversals) + var d *fs.Dirent + var err error + if args.ResolveFinal { + d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } else { + d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } if err != nil { return nil, nil, err } - - // Open file will take a reference to Dirent, so destroy this one. + // Defer a DecRef for the sake of failure cases. defer d.DecRef() - return openFile(ctx, nil, d, name) -} + if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) { + return nil, nil, syserror.ELOOP + } -// openFile performs checks on a file to be executed. If provided a *fs.File, -// openFile takes that file's Dirent and performs checks on it. If provided a -// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's -// Inode and performs checks on that. -// -// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership -// of both. -// -// "dirent" and "file" must not both be nil and point to a readable, executable, regular file. -func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) { - // file and dirent must not be nil. - if dirent == nil && file == nil { - ctx.Infof("dirent and file cannot both be nil.") - return nil, nil, syserror.ENOENT + if err := checkPermission(ctx, d); err != nil { + return nil, nil, err } - if file != nil { - dirent = file.Dirent + // If they claim it's a directory, then make sure. + // + // N.B. we reject directories below, but we must first reject + // non-directories passed as directories. + if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) { + return nil, nil, syserror.ENOTDIR } - // Perform permissions checks on the file. 
- if err := checkFile(ctx, dirent, name); err != nil { + if err := checkIsRegularFile(ctx, d, args.Filename); err != nil { return nil, nil, err } - if file == nil { - var ferr error - if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil { - return nil, nil, ferr - } - } else { - // GetFile takes a reference to the created file, so make one in the case - // that the file reference already existed. - file.IncRef() + f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) + if err != nil { + return nil, nil, err + } + // Defer a DecRef for the sake of failure cases. + defer f.DecRef() + + if err := checkPread(ctx, f, args.Filename); err != nil { + return nil, nil, err + } + + d.IncRef() + f.IncRef() + return d, f, err +} + +// checkFile performs checks on a file to be executed. +func checkFile(ctx context.Context, f *fs.File, filename string) error { + if err := checkPermission(ctx, f.Dirent); err != nil { + return err } - // We must be able to read at arbitrary offsets. - if !file.Flags().Pread { - file.DecRef() - ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags()) - return nil, nil, syserror.EACCES + if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil { + return err } - // Grab reference for caller. - dirent.IncRef() - return dirent, file, nil + return checkPread(ctx, f, filename) } -// checkFile performs file permissions checks for binaries called in openPath -// and openFile -func checkFile(ctx context.Context, d *fs.Dirent, name string) error { +// checkPermission checks whether the file is readable and executable. +func checkPermission(ctx context.Context, d *fs.Dirent) error { perms := fs.PermMask{ // TODO(gvisor.dev/issue/160): Linux requires only execute // permission, not read. However, our backing filesystems may @@ -135,26 +186,26 @@ func checkFile(ctx context.Context, d *fs.Dirent, name string) error { Read: true, Execute: true, } - if err := d.Inode.CheckPermission(ctx, perms); err != nil { - return err - } + return d.Inode.CheckPermission(ctx, perms) +} - // If they claim it's a directory, then make sure. - // - // N.B. we reject directories below, but we must first reject - // non-directories passed as directories. - if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR +// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc. +func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error { + attr := d.Inode.StableAttr + if !fs.IsRegular(attr) { + ctx.Infof("%s is not regular: %v", filename, attr) + return syserror.EACCES } + return nil +} - // No exec-ing directories, pipes, etc! - if !fs.IsRegular(d.Inode.StableAttr) { - ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr) +// checkPread checks whether we can read the file at arbitrary offsets. +func checkPread(ctx context.Context, f *fs.File, filename string) error { + if !f.Flags().Pread { + ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags()) return syserror.EACCES } - return nil - } // allocStack allocates and maps a stack in to any available part of the address space. @@ -173,45 +224,49 @@ const ( maxLoaderAttempts = 6 ) -// loadBinary loads a binary that is pointed to by "file". If nil, the path -// "filename" is resolved and loaded. +// loadExecutable loads an executable that is pointed to by args.File. If nil, +// the path args.Filename is resolved and loaded. 
If the executable is an
+// interpreter script rather than an ELF, the binary of the corresponding
+// interpreter will be loaded.
 //
 // It returns:
 //  * loadedELF, description of the loaded binary
 //  * arch.Context matching the binary arch
 //  * fs.Dirent of the binary file
-//  * Possibly updated argv
-func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+//  * Possibly updated args.Argv
+func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
 	for i := 0; i < maxLoaderAttempts; i++ {
 		var (
 			d   *fs.Dirent
-			f   *fs.File
 			err error
 		)
-		if passedFile == nil {
-			d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename)
-
+		if args.File == nil {
+			d, args.File, err = openPath(ctx, args)
+			// We will return d in the successful case, but defer a DecRef for the
+			// sake of intermediate loops and failure cases.
+			if d != nil {
+				defer d.DecRef()
+			}
+			if args.File != nil {
+				defer args.File.DecRef()
+			}
 		} else {
-			d, f, err = openFile(ctx, passedFile, nil, "")
-			// Set to nil in case we loop on a Interpreter Script.
-			passedFile = nil
+			d = args.File.Dirent
+			d.IncRef()
+			defer d.DecRef()
+			err = checkFile(ctx, args.File, args.Filename)
 		}
-
 		if err != nil {
-			ctx.Infof("Error opening %s: %v", filename, err)
+			ctx.Infof("Error opening %s: %v", args.Filename, err)
 			return loadedELF{}, nil, nil, nil, err
 		}
-		defer f.DecRef()
-		// We will return d in the successful case, but defer a DecRef
-		// for intermediate loops and failure cases.
-		defer d.DecRef()
 
 		// Check the header. Is this an ELF or interpreter script?
 		var hdr [4]uint8
 		// N.B. We assume that reading from a regular file cannot block.
-		_, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0)
-		// Allow unexpected EOF, as a valid executable could be only three
-		// bytes (e.g., #!a).
+		_, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0)
+		// Allow unexpected EOF, as a valid executable could be only three bytes
+		// (e.g., #!a).
 		if err != nil && err != io.ErrUnexpectedEOF {
 			if err == io.EOF {
 				err = syserror.ENOEXEC
@@ -221,33 +276,38 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp
 		switch {
 		case bytes.Equal(hdr[:], []byte(elfMagic)):
-			loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f)
+			loaded, ac, err := loadELF(ctx, args)
 			if err != nil {
 				ctx.Infof("Error loading ELF: %v", err)
 				return loadedELF{}, nil, nil, nil, err
 			}
 			// An ELF is always terminal. Hold on to d.
 			d.IncRef()
-			return loaded, ac, d, argv, err
+			return loaded, ac, d, args.Argv, err
 		case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
-			newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv)
+			if args.CloseOnExec {
+				return loadedELF{}, nil, nil, nil, syserror.ENOENT
+			}
+			args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv)
 			if err != nil {
 				ctx.Infof("Error loading interpreter script: %v", err)
 				return loadedELF{}, nil, nil, nil, err
 			}
-			filename = newpath
-			argv = newargv
+			// Refresh the traversal limit for the interpreter.
+			*args.RemainingTraversals = linux.MaxSymlinkTraversals
 		default:
 			ctx.Infof("Unknown magic: %v", hdr)
 			return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
 		}
+		// Set to nil in case we loop on an interpreter script.
+		args.File = nil
 	}
 	return loadedELF{}, nil, nil, nil, syserror.ELOOP
 }
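
In the interpreter-script case above, parseInterpreterScript (not shown in this diff) rewrites args.Filename and args.Argv before the loop retries, following the execve(2) "#!" rules: the interpreter becomes argv[0], at most one optional argument is kept whole, and the script path replaces the original argv[0]. A hypothetical sketch of that rewrite follows, for illustration only; rewriteShebang is not gVisor's implementation.

package main

import (
	"fmt"
	"strings"
)

// rewriteShebang maps "#!interp [optarg]" plus the original argv onto
// [interp, optarg?, scriptPath, origArgv[1:]...], per execve(2).
func rewriteShebang(firstLine, scriptPath string, argv []string) ([]string, error) {
	if !strings.HasPrefix(firstLine, "#!") {
		return nil, fmt.Errorf("not an interpreter script")
	}
	rest := strings.TrimSpace(strings.TrimPrefix(firstLine, "#!"))
	if rest == "" {
		return nil, fmt.Errorf("no interpreter specified")
	}
	// Everything after the interpreter is a single optional argument;
	// Linux does not split it further.
	fields := strings.SplitN(rest, " ", 2)
	newArgv := []string{fields[0]}
	if len(fields) == 2 {
		if optArg := strings.TrimSpace(fields[1]); optArg != "" {
			newArgv = append(newArgv, optArg)
		}
	}
	// The script path replaces the original argv[0]; the rest follows.
	newArgv = append(newArgv, scriptPath)
	if len(argv) > 1 {
		newArgv = append(newArgv, argv[1:]...)
	}
	return newArgv, nil
}

func main() {
	argv, _ := rewriteShebang("#!/bin/sh -e", "./run.sh", []string{"./run.sh", "arg"})
	fmt.Println(argv) // [/bin/sh -e ./run.sh arg]
}

The refreshed RemainingTraversals above matters here: each interpreter lookup gets a full symlink budget, while maxLoaderAttempts bounds how many "#!" levels can chain before the loop gives up with ELOOP.
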
-// Load loads "file" into a MemoryManager. If file is nil, the path "filename"
-// is resolved and loaded instead.
+// Load loads args.File into a MemoryManager. If args.File is nil, the path
+// args.Filename is resolved and loaded instead.
 //
 // If Load returns ErrSwitchFile it should be called again with the returned
 // path and argv.
@@ -255,37 +315,37 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp
 // Preconditions:
 //  * The Task MemoryManager is empty.
 //  * Load is called on the Task goroutine.
-func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
-	// Load the binary itself.
-	loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv)
+func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
+	// Load the executable itself.
+	loaded, ac, d, newArgv, err := loadExecutable(ctx, args)
 	if err != nil {
-		return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux())
+		return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
 	}
 	defer d.DecRef()
 
 	// Load the VDSO.
-	vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded)
+	vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
 	if err != nil {
 		return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
 	}
 
 	// Set up the heap. brk starts at the next page after the end of the
-	// binary. Userspace can assume that the remainer of the page after
+	// executable. Userspace can assume that the remainder of the page after
 	// loaded.end is available for its use.
 	e, ok := loaded.end.RoundUp()
 	if !ok {
 		return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC)
 	}
-	m.BrkSetup(ctx, e)
+	args.MemoryManager.BrkSetup(ctx, e)
 
 	// Allocate our stack.
-	stack, err := allocStack(ctx, m, ac)
+	stack, err := allocStack(ctx, args.MemoryManager, ac)
 	if err != nil {
 		return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux())
 	}
 
 	// Push the original filename to the stack, for AT_EXECFN.
-	execfn, err := stack.Push(filename)
+	execfn, err := stack.Push(args.Filename)
 	if err != nil {
 		return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux())
 	}
@@ -319,11 +379,12 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
 	}...)
 	auxv = append(auxv, extraAuxv...)
- sl, err := stack.Load(argv, envv, auxv) + sl, err := stack.Load(newArgv, args.Envv, auxv) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux()) } + m := args.MemoryManager m.SetArgvStart(sl.ArgvStart) m.SetArgvEnd(sl.ArgvEnd) m.SetEnvvStart(sl.EnvvStart) @@ -334,7 +395,7 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) - name := path.Base(filename) + name := path.Base(args.Filename) if len(name) > linux.TASK_COMM_LEN-1 { name = name[:linux.TASK_COMM_LEN-1] } diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index ada28aea3..df8a81907 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -268,6 +268,8 @@ func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, er // some applications may not be able to handle multiple [vdso] // hints. vdso: mm.NewSpecialMappable("", mfp, vdso), + os: info.os, + arch: info.arch, phdrs: info.phdrs, }, nil } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 3ef84245b..112794e9c 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -41,7 +41,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", - "//pkg/refs", "//pkg/sentry/context", "//pkg/sentry/platform", "//pkg/sentry/usermem", diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 03b99aaea..16a722a13 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -18,7 +18,6 @@ package memmap import ( "fmt" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" @@ -235,8 +234,11 @@ type InvalidateOpts struct { // coincidental; fs.File implements MappingIdentity, and some // fs.InodeOperations implement Mappable.) type MappingIdentity interface { - // MappingIdentity is reference-counted. - refs.RefCounter + // IncRef increments the MappingIdentity's reference count. + IncRef() + + // DecRef decrements the MappingIdentity's reference count. + DecRef() // MappedName returns the application-visible name shown in // /proc/[pid]/maps. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index a804b8b5c..839931f67 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -118,9 +118,9 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/syncutil", "//pkg/syserror", "//pkg/tcpip/buffer", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index f350e0109..58a5c186d 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -44,7 +44,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // MemoryManager implements a virtual address space. @@ -82,7 +82,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu gvsync.DowngradableRWMutex `state:"nosave"` + mappingMu syncutil.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +125,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. 
- activeMu gvsync.DowngradableRWMutex `state:"nosave"` + activeMu syncutil.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index 1effc7735..aafce1d00 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -16,6 +16,7 @@ package pgalloc import ( "bytes" + "context" "fmt" "io" "runtime" @@ -29,7 +30,7 @@ import ( ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // Save metadata. - if err := state.Save(w, &f.fileSize, nil); err != nil { + if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { return err } - if err := state.Save(w, &f.usage, nil); err != nil { + if err := state.Save(ctx, w, &f.usage, nil); err != nil { return err } @@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { // Load metadata. - if err := state.Load(r, &f.fileSize, nil); err != nil { + if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(r, &f.usage, nil); err != nil { + if err := state.Load(ctx, r, &f.usage, nil); err != nil { return err } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 31fa48ec5..f3afd98da 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -12,19 +12,30 @@ go_library( "bluepill_amd64.go", "bluepill_amd64.s", "bluepill_amd64_unsafe.go", + "bluepill_arm64.go", + "bluepill_arm64.s", + "bluepill_arm64_unsafe.go", "bluepill_fault.go", "bluepill_unsafe.go", "context.go", - "filters.go", + "filters_amd64.go", + "filters_arm64.go", "kvm.go", "kvm_amd64.go", "kvm_amd64_unsafe.go", + "kvm_arm64.go", + "kvm_arm64_unsafe.go", "kvm_const.go", + "kvm_const_arm64.go", "machine.go", "machine_amd64.go", "machine_amd64_unsafe.go", + "machine_arm64.go", + "machine_arm64_unsafe.go", "machine_unsafe.go", "physical_map.go", + "physical_map_amd64.go", + "physical_map_arm64.go", "virtual_map.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index acd41f73d..ea8b9632e 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // not have physical mappings, the KVM module may inject // spurious exceptions when emulation fails (i.e. it tries to // emulate because the RIP is pointed at those pages). - as.machine.mapPhysical(physical, length) + as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) // Install the page table mappings. 
Note that the ordering is // important; if the pagetable mappings were installed before diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go index 80942e9c9..3f35414bb 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/allocator.go @@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { // //go:nosplit func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { - virtualStart, physicalStart, _, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) } diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 043de51b3..30dbb74d6 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -20,6 +20,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" ) @@ -36,6 +37,18 @@ func sighandler() func dieTrampoline() var ( + // bounceSignal is the signal used for bouncing KVM. + // + // We use SIGCHLD because it is not masked by the runtime, and + // it will be ignored properly by other parts of the kernel. + bounceSignal = syscall.SIGCHLD + + // bounceSignalMask has only bounceSignal set. + bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) + + // bounce is the interrupt vector used to return to the kernel. + bounce = uint32(ring0.VirtualizationException) + // savedHandler is a pointer to the previous handler. // // This is called by bluepillHandler. @@ -45,6 +58,13 @@ var ( dieTrampolineAddr uintptr ) +// redpill invokes a syscall with -1. +// +//go:nosplit +func redpill() { + syscall.RawSyscall(^uintptr(0), 0, 0, 0) +} + // dieHandler is called by dieTrampoline. // //go:nosplit @@ -73,8 +93,8 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { - panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err)) + if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { + panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } // Extract the address for the trampoline. diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index 421c88220..133c2203d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -24,26 +24,10 @@ import ( ) var ( - // bounceSignal is the signal used for bouncing KVM. - // - // We use SIGCHLD because it is not masked by the runtime, and - // it will be ignored properly by other parts of the kernel. - bounceSignal = syscall.SIGCHLD - - // bounceSignalMask has only bounceSignal set. - bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1)) - - // bounce is the interrupt vector used to return to the kernel. - bounce = uint32(ring0.VirtualizationException) + // The action for bluepillSignal is changed by sigaction(). + bluepillSignal = syscall.SIGSEGV ) -// redpill on amd64 invokes a syscall with -1. -// -//go:nosplit -func redpill() { - syscall.RawSyscall(^uintptr(0), 0, 0, 0) -} - // bluepillArchEnter is called during bluepillEnter. 
// //go:nosplit diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 9d8af143e..a63a6a071 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -23,13 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) -// bluepillArchContext returns the arch-specific context. -// -//go:nosplit -func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { - return &((*arch.UContext64)(context).MContext) -} - // dieArchSetup initializes the state for dieTrampoline. // // The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go new file mode 100644 index 000000000..552341721 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -0,0 +1,79 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" +) + +var ( + // The action for bluepillSignal is changed by sigaction(). + bluepillSignal = syscall.SIGILL +) + +// bluepillArchEnter is called during bluepillEnter. +// +//go:nosplit +func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { + c = vCPUPtr(uintptr(context.Regs[8])) + regs := c.CPU.Registers() + regs.Regs = context.Regs + regs.Sp = context.Sp + regs.Pc = context.Pc + regs.Pstate = context.Pstate + regs.Pstate &^= uint64(ring0.KernelFlagsClear) + regs.Pstate |= ring0.KernelFlagsSet + return +} + +// bluepillArchExit is called during bluepillEnter. +// +//go:nosplit +func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { + regs := c.CPU.Registers() + context.Regs = regs.Regs + context.Sp = regs.Sp + context.Pc = regs.Pc + context.Pstate = regs.Pstate + context.Pstate &^= uint64(ring0.UserFlagsClear) + context.Pstate |= ring0.UserFlagsSet +} + +// KernelSyscall handles kernel syscalls. +// +//go:nosplit +func (c *vCPU) KernelSyscall() { + regs := c.Registers() + if regs.Regs[8] != ^uint64(0) { + regs.Pc -= 4 // Rewind. + } + ring0.Halt() +} + +// KernelException handles kernel exceptions. +// +//go:nosplit +func (c *vCPU) KernelException(vector ring0.Vector) { + regs := c.Registers() + if vector == ring0.Vector(bounce) { + regs.Pc = 0 + } + ring0.Halt() +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s new file mode 100644 index 000000000..c61700892 --- /dev/null +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -0,0 +1,87 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// VCPU_CPU is the location of the CPU in the vCPU struct.
+//
+// This is guaranteed to be zero.
+#define VCPU_CPU 0x0
+
+// CPU_SELF is the self reference in ring0's percpu.
+//
+// This is guaranteed to be zero.
+#define CPU_SELF 0x0
+
+// Context offsets.
+//
+// Only limited use of the context is done in the assembly stub below, most is
+// done in the Go handlers.
+#define SIGINFO_SIGNO 0x0
+#define CONTEXT_PC 0x1B8
+#define CONTEXT_R0 0xB8
+
+// See bluepill.go.
+TEXT ·bluepill(SB),NOSPLIT,$0
+begin:
+	MOVD	vcpu+0(FP), R8
+	MOVD	$VCPU_CPU(R8), R9
+	ORR	$0xffff000000000000, R9, R9
+	// Trigger sigill.
+	// In ring0.Start(), the value of R8 will be stored into tpidr_el1.
+	// When the context was loaded into vcpu successfully,
+	// we will check if the value of R10 and R9 are the same.
+	WORD	$0xd538d08a // MRS TPIDR_EL1, R10
+check_vcpu:
+	CMP	R10, R9
+	BEQ	right_vCPU
+wrong_vcpu:
+	CALL	·redpill(SB)
+	B	begin
+right_vCPU:
+	RET
+
+// sighandler: see bluepill.go for documentation.
+//
+// The arguments are the following:
+//
+//	R0 - The signal number.
+//	R1 - Pointer to siginfo_t structure.
+//	R2 - Pointer to ucontext structure.
+//
+TEXT ·sighandler(SB),NOSPLIT,$0
+	// si_signo should be sigill.
+	MOVD	SIGINFO_SIGNO(R1), R7
+	CMPW	$4, R7
+	BNE	fallback
+
+	MOVD	CONTEXT_PC(R2), R7
+	CMPW	$0, R7
+	BEQ	fallback
+
+	MOVD	R2, 8(RSP)
+	BL	·bluepillHandler(SB) // Call the handler.
+
+	RET
+
+fallback:
+	// Jump to the previous signal handler.
+	MOVD	·savedHandler(SB), R7
+	B	(R7)
+
+// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
+TEXT ·dieTrampoline(SB),NOSPLIT,$0
+	// TODO(gvisor.dev/issue/1249): dieTrampoline support for Arm64.
+	MOVD	R9, 8(RSP)
+	BL	·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
new file mode 100644
index 000000000..e5fac0d6a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+	"unsafe"
+
+	"gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+//go:nosplit
+func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+	// TODO(gvisor.dev/issue/1249): dieTrampoline support for Arm64.
+} diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index b97476053..f6459cda9 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -46,9 +46,9 @@ func yield() { // calculateBluepillFault calculates the fault address range. // //go:nosplit -func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { +func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { alignedPhysical := physical &^ uintptr(usermem.PageSize-1) - for _, pr := range physicalRegions { + for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { continue @@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng // The corresponding virtual address is returned. This may throw on error. // //go:nosplit -func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { +func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) { // Paging fault: we need to map the underlying physical pages for this // fault. This all has to be done in this function because we're in a // signal handler context. (We can't call any functions that might // split the stack.) - virtualStart, physicalStart, length, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { return 0, false } @@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { yield() // Race with another call. slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) } - errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { // Successfully added region; we can increment nextSlot and // allow another set to proceed here. diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 7e8e9f42a..9add7c944 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -23,6 +23,8 @@ import ( "sync/atomic" "syscall" "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" ) //go:linkname throw runtime.throw @@ -49,6 +51,13 @@ func uintptrValue(addr *byte) uintptr { return (uintptr)(unsafe.Pointer(addr)) } +// bluepillArchContext returns the UContext64. +// +//go:nosplit +func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 { + return &((*arch.UContext64)(context).MContext) +} + // bluepillHandler is called from the signal stub. // // The world may be stopped while this is executing, and it executes on the @@ -80,13 +89,17 @@ func bluepillHandler(context unsafe.Pointer) { // interrupted KVM. Since we're in a signal handler // currently, all signals are masked and the signal // must have been delivered directly to this thread. + timeout := syscall.Timespec{} sig, _, errno := syscall.RawSyscall6( syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), - 0, // siginfo. - 0, // timeout. - 8, // sigset size. + 0, // siginfo. + uintptr(unsafe.Pointer(&timeout)), // timeout. + 8, // sigset size. 
0, 0) + if errno == syscall.EAGAIN { + continue + } if errno != 0 { throw("error waiting for pending signal") } @@ -162,7 +175,7 @@ func bluepillHandler(context unsafe.Pointer) { // For MMIO, the physical address is the first data item. physical := uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical) + virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) if !ok { c.die(bluepillArchContext(context), "invalid physical address") return diff --git a/pkg/sentry/platform/kvm/filters.go b/pkg/sentry/platform/kvm/filters_amd64.go index 7d949f1dd..7d949f1dd 100644 --- a/pkg/sentry/platform/kvm/filters.go +++ b/pkg/sentry/platform/kvm/filters_amd64.go diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go new file mode 100644 index 000000000..9245d07c2 --- /dev/null +++ b/pkg/sentry/platform/kvm/filters_arm64.go @@ -0,0 +1,32 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +// SyscallFilters returns syscalls made exclusively by the KVM platform. +func (*KVM) SyscallFilters() seccomp.SyscallRules { + return seccomp.SyscallRules{ + syscall.SYS_IOCTL: {}, + syscall.SYS_MMAP: {}, + syscall.SYS_RT_SIGSUSPEND: {}, + syscall.SYS_RT_SIGTIMEDWAIT: {}, + 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host. + } +} diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ee4cd2f4d..f2c2c059e 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -21,7 +21,6 @@ import ( "sync" "syscall" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" @@ -56,9 +55,7 @@ func New(deviceFile *os.File) (*KVM, error) { // Ensure global initialization is done. globalOnce.Do(func() { - physicalInit() - globalErr = updateSystemValues(int(fd)) - ring0.Init(cpuid.HostFeatureSet()) + globalErr = updateGlobalOnce(int(fd)) }) if globalErr != nil { return nil, globalErr diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go index 5d8ef4761..c5a6f9c7d 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64.go +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -17,6 +17,7 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) @@ -211,3 +212,11 @@ type cpuidEntries struct { _ uint32 entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry } + +// updateGlobalOnce does global initialization. It has to be called only once. 
+func updateGlobalOnce(fd int) error { + physicalInit() + err := updateSystemValues(int(fd)) + ring0.Init(cpuid.HostFeatureSet()) + return err +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go new file mode 100644 index 000000000..2319c86d3 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -0,0 +1,83 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "syscall" +) + +// userMemoryRegion is a region of physical memory. +// +// This mirrors kvm_memory_region. +type userMemoryRegion struct { + slot uint32 + flags uint32 + guestPhysAddr uint64 + memorySize uint64 + userspaceAddr uint64 +} + +type kvmOneReg struct { + id uint64 + addr uint64 +} + +const KVM_NR_SPSR = 5 + +type userFpsimdState struct { + vregs [64]uint64 + fpsr uint32 + fpcr uint32 + reserved [2]uint32 +} + +type userRegs struct { + Regs syscall.PtraceRegs + sp_el1 uint64 + elr_el1 uint64 + spsr [KVM_NR_SPSR]uint64 + fpRegs userFpsimdState +} + +// runData is the run structure. This may be mapped for synchronous register +// access (although that doesn't appear to be supported by my kernel at least). +// +// This mirrors kvm_run. +type runData struct { + requestInterruptWindow uint8 + _ [7]uint8 + + exitReason uint32 + readyForInterruptInjection uint8 + ifFlag uint8 + _ [2]uint8 + + cr8 uint64 + apicBase uint64 + + // This is the union data for exits. Interpretation depends entirely on + // the exitReason above (see vCPU code for more information). + data [32]uint64 +} + +// updateGlobalOnce does global initialization. It has to be called only once. +func updateGlobalOnce(fd int) error { + physicalInit() + err := updateSystemValues(int(fd)) + updateVectorTable() + return err +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go new file mode 100644 index 000000000..6531bae1d --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go @@ -0,0 +1,39 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "fmt" + "syscall" +) + +var ( + runDataSize int +) + +func updateSystemValues(fd int) error { + // Extract the mmap size. + sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0) + if errno != 0 { + return fmt.Errorf("getting VCPU mmap size: %v", errno) + } + // Save the data. 
+ runDataSize = int(sz) + + // Success. + return nil +} diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index d05f05c29..1d5c77ff4 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -49,11 +49,13 @@ const ( _KVM_EXIT_SHUTDOWN = 0x8 _KVM_EXIT_FAIL_ENTRY = 0x9 _KVM_EXIT_INTERNAL_ERROR = 0x11 + _KVM_EXIT_SYSTEM_EVENT = 0x18 ) // KVM capability options. const ( - _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 ) // KVM limits. @@ -62,3 +64,10 @@ const ( _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 ) + +// KVM kvm_memory_region::flags. +const ( + _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0 + _KVM_MEM_READONLY = uint32(1) << 1 + _KVM_MEM_FLAGS_NONE = 0 +) diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go new file mode 100644 index 000000000..5a74c6e36 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -0,0 +1,132 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// KVM ioctls for Arm64. +const ( + _KVM_GET_ONE_REG = 0x4010aeab + _KVM_SET_ONE_REG = 0x4010aeac + + _KVM_ARM_PREFERRED_TARGET = 0x8020aeaf + _KVM_ARM_VCPU_INIT = 0x4020aeae + _KVM_ARM64_REGS_PSTATE = 0x6030000000100042 + _KVM_ARM64_REGS_SP_EL1 = 0x6030000000100044 + _KVM_ARM64_REGS_R0 = 0x6030000000100000 + _KVM_ARM64_REGS_R1 = 0x6030000000100002 + _KVM_ARM64_REGS_R2 = 0x6030000000100004 + _KVM_ARM64_REGS_R3 = 0x6030000000100006 + _KVM_ARM64_REGS_R8 = 0x6030000000100010 + _KVM_ARM64_REGS_R18 = 0x6030000000100024 + _KVM_ARM64_REGS_PC = 0x6030000000100040 + _KVM_ARM64_REGS_MAIR_EL1 = 0x603000000013c510 + _KVM_ARM64_REGS_TCR_EL1 = 0x603000000013c102 + _KVM_ARM64_REGS_TTBR0_EL1 = 0x603000000013c100 + _KVM_ARM64_REGS_TTBR1_EL1 = 0x603000000013c101 + _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 + _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 + _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 +) + +// Arm64: Architectural Feature Access Control Register EL1. +const ( + _FPEN_NOTRAP = 0x3 + _FPEN_SHIFT = 0x20 +) + +// Arm64: System Control Register EL1. +const ( + _SCTLR_M = 1 << 0 + _SCTLR_C = 1 << 2 + _SCTLR_I = 1 << 12 +) + +// Arm64: Translation Control Register EL1. 
+const (
+	_TCR_IPS_40BITS = 2 << 32 // PA=40
+	_TCR_IPS_48BITS = 5 << 32 // PA=48
+
+	_TCR_T0SZ_OFFSET = 0
+	_TCR_T1SZ_OFFSET = 16
+	_TCR_IRGN0_SHIFT = 8
+	_TCR_IRGN1_SHIFT = 24
+	_TCR_ORGN0_SHIFT = 10
+	_TCR_ORGN1_SHIFT = 26
+	_TCR_SH0_SHIFT   = 12
+	_TCR_SH1_SHIFT   = 28
+	_TCR_TG0_SHIFT   = 14
+	_TCR_TG1_SHIFT   = 30
+
+	_TCR_T0SZ_VA48 = 64 - 48 // VA=48
+	_TCR_T1SZ_VA48 = 64 - 48 // VA=48
+
+	_TCR_ASID16 = 1 << 36
+	_TCR_TBI0   = 1 << 37
+
+	_TCR_TXSZ_VA48 = (_TCR_T0SZ_VA48 << _TCR_T0SZ_OFFSET) | (_TCR_T1SZ_VA48 << _TCR_T1SZ_OFFSET)
+
+	_TCR_TG0_4K  = 0 << _TCR_TG0_SHIFT // 4K
+	_TCR_TG0_64K = 1 << _TCR_TG0_SHIFT // 64K
+
+	_TCR_TG1_4K = 2 << _TCR_TG1_SHIFT
+
+	_TCR_TG_FLAGS = _TCR_TG0_4K | _TCR_TG1_4K
+
+	_TCR_IRGN0_WBWA = 1 << _TCR_IRGN0_SHIFT
+	_TCR_IRGN1_WBWA = 1 << _TCR_IRGN1_SHIFT
+	_TCR_IRGN_WBWA  = _TCR_IRGN0_WBWA | _TCR_IRGN1_WBWA
+
+	_TCR_ORGN0_WBWA = 1 << _TCR_ORGN0_SHIFT
+	_TCR_ORGN1_WBWA = 1 << _TCR_ORGN1_SHIFT
+
+	_TCR_ORGN_WBWA = _TCR_ORGN0_WBWA | _TCR_ORGN1_WBWA
+
+	_TCR_SHARED = (3 << _TCR_SH0_SHIFT) | (3 << _TCR_SH1_SHIFT)
+
+	_TCR_CACHE_FLAGS = _TCR_IRGN_WBWA | _TCR_ORGN_WBWA
+)
+
+// Arm64: Memory Attribute Indirection Register EL1.
+const (
+	_MT_DEVICE_nGnRnE = 0
+	_MT_DEVICE_nGnRE  = 1
+	_MT_DEVICE_GRE    = 2
+	_MT_NORMAL_NC     = 3
+	_MT_NORMAL        = 4
+	_MT_NORMAL_WT     = 5
+
+	// Each attribute byte is placed at (index * 8). The parentheses around
+	// the shift amounts are required: << and * share precedence in Go and
+	// bind left to right, so "x << i * 8" would be "(x << i) * 8".
+	_MT_EL1_INIT = (0x00 << (_MT_DEVICE_nGnRnE * 8)) | (0x04 << (_MT_DEVICE_nGnRE * 8)) | (0x0c << (_MT_DEVICE_GRE * 8)) | (0x44 << (_MT_NORMAL_NC * 8)) | (0xff << (_MT_NORMAL * 8)) | (0xbb << (_MT_NORMAL_WT * 8))
+)
+
+const (
+	_KVM_ARM_VCPU_POWER_OFF = 0 // CPU is started in OFF state.
+	_KVM_ARM_VCPU_PSCI_0_2  = 2 // CPU uses PSCI v0.2.
+)
+
+// Arm64: Exception Syndrome Register EL1.
+const (
+	_ESR_ELx_FSC = 0x3F
+
+	_ESR_SEGV_MAPERR_L0 = 0x4
+	_ESR_SEGV_MAPERR_L1 = 0x5
+	_ESR_SEGV_MAPERR_L2 = 0x6
+	_ESR_SEGV_MAPERR_L3 = 0x7
+
+	_ESR_SEGV_ACCERR_L1 = 0x9
+	_ESR_SEGV_ACCERR_L2 = 0xa
+	_ESR_SEGV_ACCERR_L3 = 0xb
+
+	_ESR_SEGV_PEMERR_L1 = 0xd
+	_ESR_SEGV_PEMERR_L2 = 0xe
+	_ESR_SEGV_PEMERR_L3 = 0xf
+)
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index cc6c138b2..7d02ebf19 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) {
 		return true // Keep iterating.
 	})
 
+	var physicalRegionsReadOnly []physicalRegion
+	var physicalRegionsAvailable []physicalRegion
+
+	physicalRegionsReadOnly = rdonlyRegionsForSetMem()
+	physicalRegionsAvailable = availableRegionsForSetMem()
+
+	// Map all read-only regions.
+	for _, r := range physicalRegionsReadOnly {
+		m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
+	}
+
 	// Ensure that the currently mapped virtual regions are actually
 	// available in the VM. Note that this doesn't guarantee no future
 	// faults, however it should guarantee that everything is available to
@@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) {
 		if excludeVirtualRegion(vr) {
 			return // skip region.
 		}
+
+		for _, r := range physicalRegionsReadOnly {
+			if vr.virtual == r.virtual {
+				return
+			}
+		}
+
 		for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
 			physical, length, ok := translateToPhysical(virtual)
 			if !ok {
@@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) {
 			}
 
 			// Ensure the physical range is mapped.
-			m.mapPhysical(physical, length)
+			m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
 			virtual += length
 		}
 	})
@@ -256,9 +274,9 @@
 // not available.
This attempts to be efficient for calls in the hot path. // // This panics on error. -func (m *machine) mapPhysical(physical, length uintptr) { +func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { - _, physicalStart, length, ok := calculateBluepillFault(physical) + _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { // Should never happen. panic("mapPhysical on unknown physical address") @@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) { if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { // Not present in the cache; requires setting the slot. - if _, ok := handleBluepillFault(m, physical); !ok { + if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c1cbe33be..b99fe425e 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) { } } } + +// On x86 platform, the flags for "setMemoryRegion" can always be set as 0. +// There is no need to return read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + return nil +} + +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 506ec9af1..7156c245f 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -26,30 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - // loadSegments copies the current segments. // // This may be called from within the signal context and throws on error. @@ -159,3 +135,43 @@ func (c *vCPU) setSignalMask() error { } return nil } + +// setUserRegisters sets user registers in the vCPU. +func (c *vCPU) setUserRegisters(uregs *userRegs) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) + } + return nil +} + +// getUserRegisters reloads user registers in the vCPU. +// +// This is safe to call from a nosplit context. +// +//go:nosplit +func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_REGS, + uintptr(unsafe.Pointer(uregs))); errno != 0 { + return errno + } + return 0 +} + +// setSystemRegisters sets system registers. 
+func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
+	if _, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(c.fd),
+		_KVM_SET_SREGS,
+		uintptr(unsafe.Pointer(sregs))); errno != 0 {
+		return fmt.Errorf("error setting system registers: %v", errno)
+	}
+	return nil
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
new file mode 100644
index 000000000..7ae47f291
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -0,0 +1,183 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+	"gvisor.dev/gvisor/pkg/sentry/arch"
+	"gvisor.dev/gvisor/pkg/sentry/platform"
+	"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+	"gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+type vCPUArchState struct {
+	// PCIDs is the set of PCIDs for this vCPU.
+	//
+	// This starts above fixedKernelPCID.
+	PCIDs *pagetables.PCIDs
+}
+
+const (
+	// fixedKernelPCID is a fixed kernel PCID used for the kernel page
+	// tables. We must start allocating user PCIDs above this in order to
+	// avoid any conflict (see below).
+	fixedKernelPCID = 1
+
+	// poolPCIDs is the number of PCIDs to record in the database. As this
+	// grows, assignment can take longer, since it is a simple linear scan.
+	// Beyond a relatively small number, there are likely few performance
+	// benefits, since the TLB has likely long since lost any translations
+	// from more than a few PCIDs past.
+	poolPCIDs = 8
+)
+
+// rdonlyRegionsForSetMem returns all read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+	var rdonlyRegions []region
+
+	applyVirtualRegions(func(vr virtualRegion) {
+		if excludeVirtualRegion(vr) {
+			return
+		}
+
+		if !vr.accessType.Write && vr.accessType.Read {
+			rdonlyRegions = append(rdonlyRegions, vr.region)
+		}
+	})
+
+	for _, r := range rdonlyRegions {
+		physical, _, ok := translateToPhysical(r.virtual)
+		if !ok {
+			continue
+		}
+
+		phyRegions = append(phyRegions, physicalRegion{
+			region: region{
+				virtual: r.virtual,
+				length:  r.length,
+			},
+			physical: physical,
+		})
+	}
+
+	return phyRegions
+}
+
+// availableRegionsForSetMem returns all available physicalRegions.
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+	var excludeRegions []region
+	applyVirtualRegions(func(vr virtualRegion) {
+		if !vr.accessType.Write {
+			excludeRegions = append(excludeRegions, vr.region)
+		}
+	})
+
+	phyRegions = computePhysicalRegions(excludeRegions)
+
+	return phyRegions
+}
+
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// Clear from all PCIDs.
+	for _, c := range m.vCPUs {
+		c.PCIDs.Drop(pt)
+	}
+}
+
+// nonCanonical generates a signal and return value for a non-canonical address.
+// +//go:nosplit +func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + *info = arch.SignalInfo{ + Signo: signal, + Code: arch.SignalInfoKernel, + } + info.SetAddr(addr) // Include address. + return usermem.NoAccess, platform.ErrContextSignal +} + +// fault generates an appropriate fault return. +// +//go:nosplit +func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) { + faultAddr := c.GetFaultAddr() + code, user := c.ErrorCode() + + // Reset the pointed SignalInfo. + *info = arch.SignalInfo{Signo: signal} + info.SetAddr(uint64(faultAddr)) + + read := true + write := false + execute := true + + ret := code & _ESR_ELx_FSC + switch ret { + case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3: + info.Code = 1 //SEGV_MAPERR + read = false + write = true + execute = false + case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3: + info.Code = 2 // SEGV_ACCERR. + read = true + write = false + execute = false + default: + info.Code = 2 + } + + if !user { + read = true + write = false + execute = true + + } + accessType := usermem.AccessType{ + Read: read, + Write: write, + Execute: execute, + } + + return accessType, platform.ErrContextSignal +} + +// retryInGuest runs the given function in guest mode. +// +// If the function does not complete in guest mode (due to execution of a +// system call due to a GC stall, for example), then it will be retried. The +// given function must be idempotent as a result of the retry mechanism. +func (m *machine) retryInGuest(fn func()) { + c := m.Get() + defer m.Put(c) + for { + c.ClearErrorCode() // See below. + bluepill(c) // Force guest mode. + fn() // Execute the given function. + _, user := c.ErrorCode() + if user { + // If user is set, then we haven't bailed back to host + // mode via a kernel exception or system call. We + // consider the full function to have executed in guest + // mode and we can return. + break + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go new file mode 100644 index 000000000..3f2f97a6b --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package kvm + +import ( + "fmt" + "reflect" + "sync/atomic" + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. 
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
+	userRegion := userMemoryRegion{
+		slot:          uint32(slot),
+		flags:         0,
+		guestPhysAddr: uint64(physical),
+		memorySize:    uint64(length),
+		userspaceAddr: uint64(virtual),
+	}
+
+	// Set the region.
+	_, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(m.fd),
+		_KVM_SET_USER_MEMORY_REGION,
+		uintptr(unsafe.Pointer(&userRegion)))
+	return errno
+}
+
+type kvmVcpuInit struct {
+	target   uint32
+	features [7]uint32
+}
+
+var vcpuInit kvmVcpuInit
+
+// initArchState initializes architecture-specific state.
+func (m *machine) initArchState() error {
+	if _, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(m.fd),
+		_KVM_ARM_PREFERRED_TARGET,
+		uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+		panic(fmt.Sprintf("KVM_ARM_PREFERRED_TARGET failed: %v", errno))
+	}
+	return nil
+}
+
+func getPageWithReflect(p uintptr) []byte {
+	return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()]
+}
+
+// Workaround: move ring0.Vectors() to an address with 11-bit alignment.
+//
+// By the Arm64 architecture, the start address of the exception vector table
+// must be 2KB (11-bit) aligned; see the Linux kernel code in
+// arch/arm64/kernel/entry.S for reference. Go provides no way to align a
+// function's start address to a specific boundary; we have raised this
+// question with the Go community:
+// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
+// This function will be removed when Go supports this feature.
+//
+// This function does two jobs:
+// 1. Move the exception vector table to a properly aligned address.
+// 2. Fix up the offset encoded in each branch instruction.
+func updateVectorTable() {
+	fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+	offset := fromLocation & (1<<11 - 1)
+	if offset != 0 {
+		offset = 1<<11 - offset
+	}
+
+	toLocation := fromLocation + offset
+	page := getPageWithReflect(toLocation)
+	if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
+		panic(err)
+	}
+
+	page = getPageWithReflect(toLocation + 4096)
+	if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
+		panic(err)
+	}
+
+	// Move the exception vector table to the aligned address.
+	var entry *uint32
+	var entryFrom *uint32
+	for i := 1; i <= 0x800; i++ {
+		entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i)))
+		entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i)))
+		*entry = *entryFrom
+	}
+
+	// Moving the table changed the target of each unconditional branch, so
+	// patch the offset encoded in each instruction.
+	nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780}
+	for _, num := range nums {
+		entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num)))
+		*entry = *entry - (uint32)(offset / 4)
+	}
+
+	page = getPageWithReflect(toLocation)
+	if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
+		panic(err)
+	}
+
+	page = getPageWithReflect(toLocation + 4096)
+	if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
+		panic(err)
+	}
+}
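Editor's note: the rounding performed by updateVectorTable above is easy to misread, so here is a minimal standalone Go sketch (not part of the patch; alignUp2KB is a hypothetical helper name) of the same 2^11-byte alignment arithmetic:

package main

import "fmt"

// alignUp2KB rounds addr up to the next 2^11 (2KB) boundary, using the same
// offset computation as updateVectorTable.
func alignUp2KB(addr uintptr) uintptr {
	offset := addr & (1<<11 - 1)
	if offset != 0 {
		offset = 1<<11 - offset
	}
	return addr + offset
}

func main() {
	for _, addr := range []uintptr{0x4000, 0x4004, 0x47ff, 0x4801} {
		// 0x4000 stays put; the others round up to 0x4800 or 0x5000.
		fmt.Printf("%#x -> %#x\n", addr, alignUp2KB(addr))
	}
}
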
+// initArchState initializes architecture-specific state.
+func (c *vCPU) initArchState() error {
+	var (
+		reg     kvmOneReg
+		data    uint64
+		regGet  kvmOneReg
+		dataGet uint64
+	)
+
+	reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+	regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer())
+
+	vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2)
+	if _, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(c.fd),
+		_KVM_ARM_VCPU_INIT,
+		uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+		panic(fmt.Sprintf("KVM_ARM_VCPU_INIT failed: %v", errno))
+	}
+
+	// cpacr_el1
+	reg.id = _KVM_ARM64_REGS_CPACR_EL1
+	data = (_FPEN_NOTRAP << _FPEN_SHIFT)
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// sctlr_el1
+	regGet.id = _KVM_ARM64_REGS_SCTLR_EL1
+	if err := c.getOneRegister(&regGet); err != nil {
+		return err
+	}
+
+	dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I)
+	data = dataGet
+	reg.id = _KVM_ARM64_REGS_SCTLR_EL1
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// tcr_el1
+	data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
+	reg.id = _KVM_ARM64_REGS_TCR_EL1
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// mair_el1
+	data = _MT_EL1_INIT
+	reg.id = _KVM_ARM64_REGS_MAIR_EL1
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// ttbr0_el1
+	data = c.machine.kernel.PageTables.TTBR0_EL1(false, 0)
+
+	reg.id = _KVM_ARM64_REGS_TTBR0_EL1
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	c.SetTtbr0Kvm(uintptr(data))
+
+	// ttbr1_el1
+	data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
+
+	reg.id = _KVM_ARM64_REGS_TTBR1_EL1
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// sp_el1
+	data = c.CPU.StackTop()
+	reg.id = _KVM_ARM64_REGS_SP_EL1
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// pc
+	reg.id = _KVM_ARM64_REGS_PC
+	data = uint64(reflect.ValueOf(ring0.Start).Pointer())
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// r8
+	reg.id = _KVM_ARM64_REGS_R8
+	data = uint64(reflect.ValueOf(&c.CPU).Pointer())
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// vbar_el1
+	reg.id = _KVM_ARM64_REGS_VBAR_EL1
+
+	fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+	offset := fromLocation & (1<<11 - 1)
+	if offset != 0 {
+		offset = 1<<11 - offset
+	}
+
+	toLocation := fromLocation + offset
+	data = uint64(ring0.KernelStartAddress | toLocation)
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	// pstate
+	data = ring0.PsrDefaultSet | ring0.KernelFlagsSet
+	reg.id = _KVM_ARM64_REGS_PSTATE
+	if err := c.setOneRegister(&reg); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+//go:nosplit
+func (c *vCPU) loadSegments(tid uint64) {
+	// TODO(gvisor.dev/issue/1238): TLS is not supported.
+	// Get TLS from tpidr_el0.
+	atomic.StoreUint64(&c.tid, tid)
+}
+
+func (c *vCPU) setOneRegister(reg *kvmOneReg) error {
+	if _, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(c.fd),
+		_KVM_SET_ONE_REG,
+		uintptr(unsafe.Pointer(reg))); errno != 0 {
+		return fmt.Errorf("error setting one register: %v", errno)
+	}
+	return nil
+}
+
+func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
+	if _, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(c.fd),
+		_KVM_GET_ONE_REG,
+		uintptr(unsafe.Pointer(reg))); errno != 0 {
+		return fmt.Errorf("error getting one register: %v", errno)
+	}
+	return nil
+}
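Editor's note: for readers decoding the _KVM_ARM64_REGS_* values used above, they follow the KVM ONE_REG id scheme: a core-register id is KVM_REG_ARM64 | KVM_REG_SIZE_U64 | KVM_REG_ARM_CORE plus the register's offset into struct kvm_regs counted in 32-bit words. A standalone sketch (constant and function names are local to the sketch, reproduced from the KVM UAPI for illustration) that rebuilds a few of the ids set in initArchState:

package main

import "fmt"

const (
	kvmRegArm64   = 0x6000000000000000 // KVM_REG_ARM64
	kvmRegSizeU64 = 0x0030000000000000 // KVM_REG_SIZE_U64
	kvmRegArmCore = 0x00100000         // KVM_REG_ARM_CORE (0x0010 << 16)
)

// coreReg builds a ONE_REG id from an offset into struct kvm_regs,
// expressed in 32-bit words; each __u64 register spans two words.
func coreReg(wordOffset uint64) uint64 {
	return kvmRegArm64 | kvmRegSizeU64 | kvmRegArmCore | wordOffset
}

func main() {
	// user_pt_regs is regs[31], sp, pc, pstate.
	fmt.Printf("R8:     %#x\n", coreReg(2*8)) // 0x6030000000100010
	fmt.Printf("PC:     %#x\n", coreReg(64))  // 0x6030000000100040
	fmt.Printf("PSTATE: %#x\n", coreReg(66))  // 0x6030000000100042
	fmt.Printf("SP_EL1: %#x\n", coreReg(68))  // 0x6030000000100044
}
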
+// setCPUID sets the CPUID to be used by the guest.
+func (c *vCPU) setCPUID() error {
+	return nil
+}
+
+// setSystemTime sets the TSC for the vCPU.
+func (c *vCPU) setSystemTime() error {
+	return nil
+}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+	// The layout of this structure implies that it will not necessarily be
+	// the same layout chosen by the Go compiler. It gets fudged here.
+	var data struct {
+		length uint32
+		mask1  uint32
+		mask2  uint32
+		_      uint32
+	}
+	data.length = 8 // Fixed sigset size.
+	data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+	data.mask2 = ^uint32(bounceSignalMask >> 32)
+	if _, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(c.fd),
+		_KVM_SET_SIGNAL_MASK,
+		uintptr(unsafe.Pointer(&data))); errno != 0 {
+		return fmt.Errorf("error setting signal mask: %v", errno)
+	}
+
+	return nil
+}
+
+// SwitchToUser unpacks architectural details and switches to user mode.
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+	// Check for canonical addresses.
+	if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
+		return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
+	} else if !ring0.IsCanonical(regs.Sp) {
+		return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+	}
+
+	var vector ring0.Vector
+	ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0)
+	c.SetTtbr0App(uintptr(ttbr0App))
+
+	// TODO(gvisor.dev/issue/1238): full context-switch support for Arm64.
+	// The Arm64 user-mode execution state consists of:
+	// x0-x30
+	// PC, SP, PSTATE
+	// V0-V31: 32 128-bit registers for floating point and SIMD
+	// FPSR
+	// TPIDR_EL0, used for TLS
+	appRegs := switchOpts.Registers
+	c.SetAppAddr(ring0.KernelStartAddress | uintptr(unsafe.Pointer(appRegs)))
+
+	entersyscall()
+	bluepill(c)
+	vector = c.CPU.SwitchToUser(switchOpts)
+	exitsyscall()
+
+	switch vector {
+	case ring0.Syscall:
+		// Fast path: system call executed.
+		return usermem.NoAccess, nil
+	case ring0.PageFault:
+		return c.fault(int32(syscall.SIGSEGV), info)
+	case 0xaa:
+		return usermem.NoAccess, nil
+	default:
+		return usermem.NoAccess, platform.ErrContextSignal
+	}
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 405e00292..f04be2ab5 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 // +build go1.12
-// +build !go1.14
+// +build !go1.15
 
 // Check go:linkname function signatures when updating Go version.
 
@@ -35,6 +35,30 @@ func entersyscall()
 //go:linkname exitsyscall runtime.exitsyscall
 func exitsyscall()
 
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno {
+	userRegion := userMemoryRegion{
+		slot:          uint32(slot),
+		flags:         uint32(flags),
+		guestPhysAddr: uint64(physical),
+		memorySize:    uint64(length),
+		userspaceAddr: uint64(virtual),
+	}
+
+	// Set the region.
+	_, _, errno := syscall.RawSyscall(
+		syscall.SYS_IOCTL,
+		uintptr(m.fd),
+		_KVM_SET_USER_MEMORY_REGION,
+		uintptr(unsafe.Pointer(&userRegion)))
+	return errno
+}
+
 // mapRunData maps the vCPU run data.
func mapRunData(fd int) (*runData, error) { r, _, errno := syscall.RawSyscall6( @@ -63,46 +87,6 @@ func unmapRunData(r *runData) error { return nil } -// setUserRegisters sets user registers in the vCPU. -func (c *vCPU) setUserRegisters(uregs *userRegs) error { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_REGS, - uintptr(unsafe.Pointer(uregs))); errno != 0 { - return fmt.Errorf("error setting user registers: %v", errno) - } - return nil -} - -// getUserRegisters reloads user registers in the vCPU. -// -// This is safe to call from a nosplit context. -// -//go:nosplit -func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_GET_REGS, - uintptr(unsafe.Pointer(uregs))); errno != 0 { - return errno - } - return 0 -} - -// setSystemRegisters sets system registers. -func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_SREGS, - uintptr(unsafe.Pointer(sregs))); errno != 0 { - return fmt.Errorf("error setting system registers: %v", errno) - } - return nil -} - // atomicAddressSpace is an atomic address space pointer. type atomicAddressSpace struct { pointer unsafe.Pointer diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index 586e91bb2..91de5dab1 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -24,15 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) -const ( - // reservedMemory is a chunk of physical memory reserved starting at - // physical address zero. There are some special pages in this region, - // so we just call the whole thing off. - // - // Other architectures may define this to be zero. - reservedMemory = 0x100000000 -) - type region struct { virtual uintptr length uintptr @@ -59,8 +50,7 @@ func fillAddressSpace() (excludedRegions []region) { // We can cut vSize in half, because the kernel will be using the top // half and we ignore it while constructing mappings. It's as if we've // already excluded half the possible addresses. - vSize := uintptr(1) << ring0.VirtualAddressBits() - vSize = vSize >> 1 + vSize := ring0.UserspaceSize // We exclude reservedMemory below from our physical memory size, so it // needs to be dropped here as well. Otherwise, we could end up with diff --git a/pkg/sentry/platform/kvm/physical_map_amd64.go b/pkg/sentry/platform/kvm/physical_map_amd64.go new file mode 100644 index 000000000..c5adfb577 --- /dev/null +++ b/pkg/sentry/platform/kvm/physical_map_amd64.go @@ -0,0 +1,22 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +const ( + // reservedMemory is a chunk of physical memory reserved starting at + // physical address zero. There are some special pages in this region, + // so we just call the whole thing off. 
+	reservedMemory = 0x100000000
+)
diff --git a/pkg/sentry/platform/kvm/physical_map_arm64.go b/pkg/sentry/platform/kvm/physical_map_arm64.go
new file mode 100644
index 000000000..4d8561453
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map_arm64.go
@@ -0,0 +1,19 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+const (
+	reservedMemory = 0
+)
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 2cd28b2d2..0bebee852 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -50,6 +50,21 @@ TEXT ·SpinLoop(SB),NOSPLIT,$0
 start:
 	B	start
 
+TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8
+	NO_LOCAL_POINTERS
+	FMOVD	$(9.9), F0
+	MOVD	$SYS_GETPID, R8 // getpid
+	SVC
+	FMOVD	$(9.9), F1
+	FCMPD	F0, F1
+	BNE	isNaN
+	MOVD	$1, R0
+	MOVD	R0, ret+0(FP)
+	RET
+isNaN:
+	MOVD	$0, ret+0(FP)
+	RET
+
 // MVN: bitwise logical NOT
 // This case simulates an application that modified R0-R30.
 #define TWIDDLE_REGS() \
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index ebcc8c098..0df8cfa0f 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -28,6 +28,7 @@ go_library(
         "//pkg/procid",
         "//pkg/seccomp",
         "//pkg/sentry/arch",
+        "//pkg/sentry/hostcpu",
         "//pkg/sentry/platform",
         "//pkg/sentry/platform/interrupt",
         "//pkg/sentry/platform/safecopy",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 9f0ecfbe4..ddb1f41e3 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -327,6 +327,20 @@ func (t *thread) dumpAndPanic(message string) {
 	panic(message)
 }
 
+func (t *thread) unexpectedStubExit() {
+	msg, err := t.getEventMessage()
+	status := syscall.WaitStatus(msg)
+	if status.Signaled() && status.Signal() == syscall.SIGKILL {
+		// SIGKILL can only be sent by a user or the OOM killer. In
+		// either case there is no need to panic; there is no reason to
+		// think that anything is wrong in gVisor itself.
+		log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
+		pid := os.Getpid()
+		syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL))
+	}
+	t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+}
+
 // wait waits for a stop event.
 //
 // Precondition: outcome is a valid waitOutcome.
@@ -355,8 +369,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
 		}
 		if stopSig == syscall.SIGTRAP {
 			if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
-				msg, err := t.getEventMessage()
-				t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+				t.unexpectedStubExit()
 			}
 			// Re-encode the trap cause the way it's expected.
 			return stopSig | syscall.Signal(status.TrapCause()<<8)
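Editor's note: the wait() hunk above packs the stop signal and the ptrace event into a single syscall.Signal, with the trap cause re-encoded in the high byte. A small standalone sketch of that decoding (decodeStop is a hypothetical name; only standard syscall package calls are used):

package main

import (
	"fmt"
	"syscall"
)

// decodeStop mirrors thread.wait: for a SIGTRAP stop, the ptrace trap cause
// (such as PTRACE_EVENT_EXIT) is folded into the high byte of the result.
func decodeStop(status syscall.WaitStatus) syscall.Signal {
	stopSig := status.StopSignal()
	if stopSig == syscall.SIGTRAP {
		if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
			fmt.Println("stub reached PTRACE_EVENT_EXIT")
		}
		// Re-encode the trap cause the way the caller expects.
		return stopSig | syscall.Signal(status.TrapCause()<<8)
	}
	return stopSig
}

func main() {
	// Synthetic status: a SIGTRAP stop (0x7f) carrying PTRACE_EVENT_EXIT.
	status := syscall.WaitStatus(0x7f | uint32(syscall.SIGTRAP)<<8 | uint32(syscall.PTRACE_EVENT_EXIT)<<16)
	fmt.Printf("re-encoded: %#x\n", decodeStop(status)) // 0x605
}
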
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 4649a94a7..606dc2b1d 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -21,6 +21,8 @@ import (
 	"strings"
 	"syscall"
 
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/seccomp"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
 )
 
@@ -143,3 +145,49 @@ func (t *thread) adjustInitRegsRip() {
 func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
 	initregs.R15 = uint64(ppid)
 }
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+	if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+		signalInfo.Signo = int32(linux.SIGSEGV)
+
+		// Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+		// with the si_call_addr field pointing to the current RIP. This field
+		// aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+		// anything there. We do need to unwind emulation however, so we set the
+		// instruction pointer to the faulting value, and "unpop" the stack.
+		regs.Rip = signalInfo.Addr()
+		regs.Rsp -= 8
+	}
+}
+
+// enableCpuidFault enables cpuid-faulting.
+//
+// This may fail on older kernels or hardware, so we just disregard the result.
+// Host CPUID will be enabled.
+//
+// This is safe to call in an afterFork context.
+//
+//go:nosplit
+func enableCpuidFault() {
+	syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
+}
+
+// appendArchSeccompRules appends architecture-specific seccomp rules when
+// creating the BPF program. See attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+	return append(rules, seccomp.RuleSet{
+		Rules: seccomp.SyscallRules{
+			syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+				{seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+			},
+		},
+		Action: linux.SECCOMP_RET_ALLOW,
+	})
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index bec884ba5..62a686ee7 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -17,8 +17,12 @@
 package ptrace
 
 import (
+	"fmt"
+	"strings"
 	"syscall"
 
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/seccomp"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
 )
 
@@ -37,7 +41,7 @@ const (
 // resetSysemuRegs sets up emulation registers.
 //
 // This should be called prior to calling sysemu.
-func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) {
+func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
 }
 
 // createSyscallRegs sets up syscall registers.
@@ -124,3 +128,36 @@ func (t *thread) adjustInitRegsRip() {
 func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
 	initregs.Regs[7] = uint64(ppid)
 }
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Pc as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+	if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+		signalInfo.Signo = int32(linux.SIGSEGV)
+
+		// Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+		// with the si_call_addr field pointing to the current PC. This field
+		// aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+		// anything there. We do need to unwind emulation however, so we set the
+		// instruction pointer to the faulting value, and "unpop" the stack.
+		regs.Pc = signalInfo.Addr()
+		regs.Sp -= 8
+	}
+}
+
+// enableCpuidFault is a no-op on arm64; cpuid faulting is x86-specific.
+//
+//go:nosplit
+func enableCpuidFault() {
+}
+
+// appendArchSeccompRules appends architecture-specific seccomp rules when
+// creating the BPF program. See attachedThread() for more detail. No
+// additional rules are needed on arm64.
+func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+	return rules
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index c075b5f91..cf13ea5e4 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"syscall"
 
+	"golang.org/x/sys/unix"
 	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/log"
 	"gvisor.dev/gvisor/pkg/procid"
@@ -77,27 +78,6 @@ func probeSeccomp() bool {
 	}
 }
 
-// patchSignalInfo patches the signal info to account for hitting the seccomp
-// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
-// synchronous trap, but patch the structure to appear like a SIGSEGV with the
-// Rip as the faulting address.
-//
-// Note that this should only be called after verifying that the signalInfo has
-// been generated by the kernel.
-func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
-	if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
-		signalInfo.Signo = int32(linux.SIGSEGV)
-
-		// Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
-		// with the si_call_addr field pointing to the current RIP. This field
-		// aligns with the si_addr field for a SIGSEGV, so we don't need to touch
-		// anything there. We do need to unwind emulation however, so we set the
-		// instruction pointer to the faulting value, and "unpop" the stack.
-		regs.Rip = signalInfo.Addr()
-		regs.Rsp -= 8
-	}
-}
-
 // createStub creates a fresh stub processes.
 //
 // Precondition: the runtime OS thread must be locked.
@@ -129,6 +109,9 @@ func createStub() (*thread, error) {
 	// transitively) will be killed as well. It's simply not possible to
 	// safely handle a single stub getting killed: the exact state of
 	// execution is unknown and not recoverable.
+	//
+	// In addition, we set the PTRACE_O_TRACEEXIT option to log more
+	// information about a stub process when it receives a fatal signal.
 	return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction)
 }
 
@@ -146,7 +129,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
 		Rules: seccomp.SyscallRules{
 			syscall.SYS_GETTIMEOFDAY: {},
 			syscall.SYS_TIME:         {},
-			309: {}, // SYS_GETCPU.
+			unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
}, Action: linux.SECCOMP_RET_TRAP, Vsyscall: true, @@ -170,10 +153,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // For the initial process creation. syscall.SYS_WAIT4: {}, - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, - }, - syscall.SYS_EXIT: {}, + syscall.SYS_EXIT: {}, // For the stub prctl dance (all). syscall.SYS_PRCTL: []seccomp.Rule{ @@ -193,6 +173,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }, Action: linux.SECCOMP_RET_ALLOW, }) + + rules = appendArchSeccompRules(rules) } instrs, err := seccomp.BuildProgram(rules, defaultAction) if err != nil { @@ -264,9 +246,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0) } - // Enable cpuid-faulting; this may fail on older kernels or hardware, - // so we just disregard the result. Host CPUID will be enabled. - syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0) + // Enable cpuid-faulting. + enableCpuidFault() // Call the stub; should not return. stubCall(stubStart, ppid) diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index de6783fb0..2e6fbe488 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -25,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/hostcpu" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, @@ -49,20 +50,6 @@ func unmaskAllSignals() syscall.Errno { return errno } -// getCPU gets the current CPU. -// -// Precondition: the current runtime thread should be locked. -func getCPU() (uint32, error) { - var cpu uintptr - if _, _, errno := syscall.RawSyscall( - unix.SYS_GETCPU, - uintptr(unsafe.Pointer(&cpu)), - 0, 0); errno != 0 { - return 0, errno - } - return uint32(cpu), nil -} - // setCPU sets the CPU affinity. func (t *thread) setCPU(cpu uint32) error { mask := maskPool.Get().([]uintptr) @@ -93,10 +80,8 @@ func (t *thread) setCPU(cpu uint32) error { // // Precondition: the current runtime thread should be locked. func (t *thread) bind() { - currentCPU, err := getCPU() - if err != nil { - return - } + currentCPU := hostcpu.GetCPU() + if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { // Set the affinity on the thread and save the CPU for next // round; we don't expect CPUs to bounce around too frequently. diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index b80a3604d..2ae6b9f9d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. 
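Editor's note: to see how the new per-architecture appendArchSeccompRules hook composes with filter construction, here is a condensed sketch of the attachedThread flow; buildStubFilterSketch is a hypothetical wrapper, only identifiers visible in this patch are assumed, and installing the program is elided:

package ptrace

import (
	"syscall"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/seccomp"
)

// buildStubFilterSketch builds a tiny allow-list and lets the arch hook
// extend it: amd64 adds the ARCH_PRCTL(ARCH_SET_CPUID, 0) rule, while
// arm64 returns the rules unchanged.
func buildStubFilterSketch(defaultAction linux.BPFAction) error {
	rules := []seccomp.RuleSet{{
		Rules: seccomp.SyscallRules{
			syscall.SYS_WAIT4: {},
			syscall.SYS_EXIT:  {},
		},
		Action: linux.SECCOMP_RET_ALLOW,
	}}
	rules = appendArchSeccompRules(rules)

	// Compile the rule sets to BPF, exactly as attachedThread does.
	instrs, err := seccomp.BuildProgram(rules, defaultAction)
	if err != nil {
		return err
	}
	_ = instrs
	return nil
}
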
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 48b0ceaec..87f4552b5 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -4,7 +4,7 @@ load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
 package(licenses = ["notice"])
 
 go_template(
-    name = "defs",
+    name = "defs_amd64",
     srcs = [
         "defs.go",
         "defs_amd64.go",
@@ -14,11 +14,29 @@ go_template(
     visibility = [":__subpackages__"],
 )
 
+go_template(
+    name = "defs_arm64",
+    srcs = [
+        "aarch64.go",
+        "defs.go",
+        "defs_arm64.go",
+        "offsets_arm64.go",
+    ],
+    visibility = [":__subpackages__"],
+)
+
 go_template_instance(
-    name = "defs_impl",
-    out = "defs_impl.go",
+    name = "defs_impl_amd64",
+    out = "defs_impl_amd64.go",
     package = "ring0",
-    template = ":defs",
+    template = ":defs_amd64",
+)
+
+go_template_instance(
+    name = "defs_impl_arm64",
+    out = "defs_impl_arm64.go",
+    package = "ring0",
+    template = ":defs_arm64",
 )
 
 genrule(
@@ -29,17 +47,31 @@ genrule(
     tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
 )
 
+genrule(
+    name = "entry_impl_arm64",
+    srcs = ["entry_arm64.s"],
+    outs = ["entry_impl_arm64.s"],
+    cmd = "(echo -e '// +build arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+    tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
+
 go_library(
     name = "ring0",
     srcs = [
-        "defs_impl.go",
+        "defs_impl_amd64.go",
+        "defs_impl_arm64.go",
         "entry_amd64.go",
+        "entry_arm64.go",
         "entry_impl_amd64.s",
+        "entry_impl_arm64.s",
         "kernel.go",
         "kernel_amd64.go",
+        "kernel_arm64.go",
        "kernel_unsafe.go",
         "lib_amd64.go",
         "lib_amd64.s",
+        "lib_arm64.go",
+        "lib_arm64.s",
         "ring0.go",
     ],
     importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0",
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
new file mode 100644
index 000000000..6b078cd1e
--- /dev/null
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -0,0 +1,109 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// Useful bits.
+const (
+	_PGD_PGT_BASE = 0x1000
+	_PGD_PGT_SIZE = 0x1000
+	_PUD_PGT_BASE = 0x2000
+	_PUD_PGT_SIZE = 0x1000
+	_PMD_PGT_BASE = 0x3000
+	_PMD_PGT_SIZE = 0x4000
+	_PTE_PGT_BASE = 0x7000
+	_PTE_PGT_SIZE = 0x1000
+
+	_PSR_MODE_EL0t = 0x0
+	_PSR_MODE_EL1t = 0x4
+	_PSR_MODE_EL1h = 0x5
+	_PSR_EL_MASK   = 0xf
+
+	_PSR_D_BIT = 0x200
+	_PSR_A_BIT = 0x100
+	_PSR_I_BIT = 0x80
+	_PSR_F_BIT = 0x40
+)
+
+const (
+	// KernelFlagsSet should always be set in the kernel.
+	KernelFlagsSet = _PSR_MODE_EL1h
+
+	// UserFlagsSet are always set in userspace.
+	UserFlagsSet = _PSR_MODE_EL0t
+
+	KernelFlagsClear = _PSR_EL_MASK
+	UserFlagsClear   = _PSR_EL_MASK
+
+	PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT
+)
+
+// Vector is an exception vector.
+type Vector uintptr
+
+// Exception vectors.
+// Vector is an exception vector. +type Vector uintptr + +// Exception vectors. +const ( + El1SyncInvalid = iota + El1IrqInvalid + El1FiqInvalid + El1ErrorInvalid + El1Sync + El1Irq + El1Fiq + El1Error + El0Sync + El0Irq + El0Fiq + El0Error + El0Sync_invalid + El0Irq_invalid + El0Fiq_invalid + El0Error_invalid + El1Sync_da + El1Sync_ia + El1Sync_sp_pc + El1Sync_undef + El1Sync_dbg + El1Sync_inv + El0Sync_svc + El0Sync_da + El0Sync_ia + El0Sync_fpsimd_acc + El0Sync_sve_acc + El0Sync_sys + El0Sync_sp_pc + El0Sync_undef + El0Sync_dbg + El0Sync_inv + VirtualizationException + _NR_INTERRUPTS +) + +// System call vectors. +const ( + Syscall Vector = El0Sync_svc + PageFault Vector = El0Sync_da +) + +// VirtualAddressBits returns the number of bits available for virtual addresses. +func VirtualAddressBits() uint32 { + return 48 +} + +// PhysicalAddressBits returns the number of bits available for physical addresses. +func PhysicalAddressBits() uint32 { + return 40 +} diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 076063f85..3f094c2a7 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -20,17 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" ) -var ( - // UserspaceSize is the total size of userspace. - UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) - - // MaximumUserAddress is the largest possible user address. - MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) - - // KernelStartAddress is the starting kernel address. - KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) -) - // Kernel is a global kernel object. // // This contains global state, shared by multiple CPUs. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 7206322b1..10dbd381f 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -20,6 +20,17 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) + // Segment indices and Selectors. const ( + // Index into GDT array. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go new file mode 100644 index 000000000..dc0eeec01 --- /dev/null +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -0,0 +1,137 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +var ( + // UserspaceSize is the total size of userspace. + UserspaceSize = uintptr(1) << (VirtualAddressBits()) + + // MaximumUserAddress is the largest possible user address. + MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1) + + // KernelStartAddress is the starting kernel address. + KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) +) +
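The resulting layout splits the 64-bit address space at the 48-bit boundary, with a non-canonical hole in between. A quick standalone check of what these expressions evaluate to (hypothetical main package; usermem.PageSize is assumed to be 4K):

    package main

    import "fmt"

    const pageSize = 0x1000 // assumed usermem.PageSize

    func main() {
        vaBits := uint(48) // VirtualAddressBits() above
        userspaceSize := uint64(1) << vaBits
        maxUserAddr := (userspaceSize - 1) &^ uint64(pageSize-1)
        kernelStart := ^uint64(0) - (userspaceSize - 1)

        fmt.Printf("UserspaceSize:      %#x\n", userspaceSize) // 0x1000000000000
        fmt.Printf("MaximumUserAddress: %#x\n", maxUserAddr)   // 0xfffffffff000
        fmt.Printf("KernelStartAddress: %#x\n", kernelStart)   // 0xffff000000000000
    }

Note that the arm64 variant shifts by the full VirtualAddressBits() rather than amd64's VirtualAddressBits()-1: user and kernel mappings live in separate translation tables (TTBR0_EL1 vs. TTBR1_EL1; see the pagetables changes later in this diff), so userspace gets the whole lower half.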
+// KernelOpts has initialization options for the kernel. +type KernelOpts struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables +} + +// KernelArchState contains architecture-specific state. +type KernelArchState struct { + KernelOpts +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { + // stack is the stack used for interrupts on this CPU. + stack [512]byte + + // errorCode is the error code from the last exception. + errorCode uintptr + + // errorType indicates the type of the error code here; it is always set + // along with the errorCode value above. + // + // It will either be 1, which indicates a user error, or 0, indicating a + // kernel error. If the error code below returns false (kernel error), + // then it cannot provide relevant information about the last + // exception. + errorType uintptr + + // faultAddr is the value of far_el1. + faultAddr uintptr + + // ttbr0Kvm is the value of ttbr0_el1 for the sentry. + ttbr0Kvm uintptr + + // ttbr0App is the value of ttbr0_el1 for the application. + ttbr0App uintptr + + // vecCode is the exception vector. + vecCode Vector + + // appAddr is the application context pointer. + appAddr uintptr + + // lazyVFP is the value of cpacr_el1. + lazyVFP uintptr +} + +// ErrorCode returns the last error code. +// +// The returned boolean indicates whether the error code corresponds to the +// last user error or not. If it does not, then fault information must be +// ignored. This is generally the result of a kernel fault while servicing a +// user fault. +// +//go:nosplit +func (c *CPU) ErrorCode() (value uintptr, user bool) { + return c.errorCode, c.errorType != 0 +} + +// ClearErrorCode resets the error code. +// +//go:nosplit +func (c *CPU) ClearErrorCode() { + c.errorCode = 0 // No code. + c.errorType = 1 // User mode. +} + +//go:nosplit +func (c *CPU) GetFaultAddr() (value uintptr) { + return c.faultAddr +} + +//go:nosplit +func (c *CPU) SetTtbr0Kvm(value uintptr) { + c.ttbr0Kvm = value +} + +//go:nosplit +func (c *CPU) SetTtbr0App(value uintptr) { + c.ttbr0App = value +} + +//go:nosplit +func (c *CPU) GetVector() (value Vector) { + return c.vecCode +} + +//go:nosplit +func (c *CPU) SetAppAddr(value uintptr) { + c.appAddr = value +} + +// SwitchArchOpts are embedded in SwitchOpts. +type SwitchArchOpts struct { + // UserASID is the application ASID to be used on switch. + UserASID uint16 + + // KernelASID is the kernel ASID to be used on return. + KernelASID uint16 +} + +func init() { +}
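The errorCode/errorType pair above is a two-word protocol: the type word says whether the code came from a user or a kernel fault. A minimal, self-contained sketch of a consumer (the type and values here are stand-ins for illustration, not gVisor API):

    package main

    import "fmt"

    // cpuState stands in for the errorCode/errorType fields of CPUArchState.
    type cpuState struct {
        errorCode uintptr
        errorType uintptr
    }

    // ErrorCode mirrors (*CPU).ErrorCode above.
    func (c *cpuState) ErrorCode() (uintptr, bool) {
        return c.errorCode, c.errorType != 0
    }

    func main() {
        c := &cpuState{errorCode: 0x92000045, errorType: 1} // hypothetical ESR value
        if code, user := c.ErrorCode(); user {
            fmt.Printf("user fault, code %#x: deliver to the application\n", code)
        } else {
            fmt.Println("kernel fault: fault information must be ignored")
        }
    }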
diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go new file mode 100644 index 000000000..0dfa42c36 --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_arm64.go @@ -0,0 +1,60 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// These functions are implemented in entry_arm64.s. +// +// The exception vectors below are invoked in two situations: +// +// (1) The guest kernel has taken an exception (e.g. a system call). +// (2) The guest application has taken an exception. +// +// The exception level is examined to determine whether the exception came +// from kernel mode or not, and the appropriate stub is called. + +func El1_sync_invalid() +func El1_irq_invalid() +func El1_fiq_invalid() +func El1_error_invalid() + +func El1_sync() +func El1_irq() +func El1_fiq() +func El1_error() + +func El0_sync() +func El0_irq() +func El0_fiq() +func El0_error() + +func El0_sync_invalid() +func El0_irq_invalid() +func El0_fiq_invalid() +func El0_error_invalid() + +func Vectors() + +// Start is the CPU entrypoint. +// +// The CPU state will be set to c.Registers(). +func Start() +func kernelExitToEl1() + +func kernelExitToEl0() + +// Shutdown stops execution. +func Shutdown() diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s new file mode 100644 index 000000000..add2c3e08 --- /dev/null +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -0,0 +1,591 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "funcdata.h" +#include "textflag.h" + +// NB: Offsets are programmatically generated (see BUILD). +// +// This file is concatenated with the definitions. + +// Saves a register set. +// +// This is a macro because it may need to be executed in contexts where a +// stack is not available for calls. 
+// + +#define ERET() \ + WORD $0xd69f03e0 + +#define RSV_REG R18_PLATFORM +#define RSV_REG_APP R9 + +#define FPEN_NOTRAP 0x3 +#define FPEN_SHIFT 20 + +#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) + +#define REGISTERS_SAVE(reg, offset) \ + MOVD R0, offset+PTRACE_R0(reg); \ + MOVD R1, offset+PTRACE_R1(reg); \ + MOVD R2, offset+PTRACE_R2(reg); \ + MOVD R3, offset+PTRACE_R3(reg); \ + MOVD R4, offset+PTRACE_R4(reg); \ + MOVD R5, offset+PTRACE_R5(reg); \ + MOVD R6, offset+PTRACE_R6(reg); \ + MOVD R7, offset+PTRACE_R7(reg); \ + MOVD R8, offset+PTRACE_R8(reg); \ + MOVD R10, offset+PTRACE_R10(reg); \ + MOVD R11, offset+PTRACE_R11(reg); \ + MOVD R12, offset+PTRACE_R12(reg); \ + MOVD R13, offset+PTRACE_R13(reg); \ + MOVD R14, offset+PTRACE_R14(reg); \ + MOVD R15, offset+PTRACE_R15(reg); \ + MOVD R16, offset+PTRACE_R16(reg); \ + MOVD R17, offset+PTRACE_R17(reg); \ + MOVD R19, offset+PTRACE_R19(reg); \ + MOVD R20, offset+PTRACE_R20(reg); \ + MOVD R21, offset+PTRACE_R21(reg); \ + MOVD R22, offset+PTRACE_R22(reg); \ + MOVD R23, offset+PTRACE_R23(reg); \ + MOVD R24, offset+PTRACE_R24(reg); \ + MOVD R25, offset+PTRACE_R25(reg); \ + MOVD R26, offset+PTRACE_R26(reg); \ + MOVD R27, offset+PTRACE_R27(reg); \ + MOVD g, offset+PTRACE_R28(reg); \ + MOVD R29, offset+PTRACE_R29(reg); \ + MOVD R30, offset+PTRACE_R30(reg); + +#define REGISTERS_LOAD(reg, offset) \ + MOVD offset+PTRACE_R0(reg), R0; \ + MOVD offset+PTRACE_R1(reg), R1; \ + MOVD offset+PTRACE_R2(reg), R2; \ + MOVD offset+PTRACE_R3(reg), R3; \ + MOVD offset+PTRACE_R4(reg), R4; \ + MOVD offset+PTRACE_R5(reg), R5; \ + MOVD offset+PTRACE_R6(reg), R6; \ + MOVD offset+PTRACE_R7(reg), R7; \ + MOVD offset+PTRACE_R8(reg), R8; \ + MOVD offset+PTRACE_R10(reg), R10; \ + MOVD offset+PTRACE_R11(reg), R11; \ + MOVD offset+PTRACE_R12(reg), R12; \ + MOVD offset+PTRACE_R13(reg), R13; \ + MOVD offset+PTRACE_R14(reg), R14; \ + MOVD offset+PTRACE_R15(reg), R15; \ + MOVD offset+PTRACE_R16(reg), R16; \ + MOVD offset+PTRACE_R17(reg), R17; \ + MOVD offset+PTRACE_R19(reg), R19; \ + MOVD offset+PTRACE_R20(reg), R20; \ + MOVD offset+PTRACE_R21(reg), R21; \ + MOVD offset+PTRACE_R22(reg), R22; \ + MOVD offset+PTRACE_R23(reg), R23; \ + MOVD offset+PTRACE_R24(reg), R24; \ + MOVD offset+PTRACE_R25(reg), R25; \ + MOVD offset+PTRACE_R26(reg), R26; \ + MOVD offset+PTRACE_R27(reg), R27; \ + MOVD offset+PTRACE_R28(reg), g; \ + MOVD offset+PTRACE_R29(reg), R29; \ + MOVD offset+PTRACE_R30(reg), R30; + +//NOP +#define nop31Instructions() \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; \ + WORD $0xd503201f; + +#define ESR_ELx_EC_UNKNOWN (0x00) +#define ESR_ELx_EC_WFx (0x01) +/* Unallocated EC: 0x02 */ +#define ESR_ELx_EC_CP15_32 (0x03) +#define ESR_ELx_EC_CP15_64 (0x04) +#define ESR_ELx_EC_CP14_MR (0x05) +#define ESR_ELx_EC_CP14_LS (0x06) +#define ESR_ELx_EC_FP_ASIMD (0x07) +#define ESR_ELx_EC_CP10_ID (0x08) /* EL2 only */ +#define ESR_ELx_EC_PAC (0x09) /* EL2 and above */ +/* Unallocated EC: 
0x0A - 0x0B */ +#define ESR_ELx_EC_CP14_64 (0x0C) +/* Unallocated EC: 0x0d */ +#define ESR_ELx_EC_ILL (0x0E) +/* Unallocated EC: 0x0F - 0x10 */ +#define ESR_ELx_EC_SVC32 (0x11) +#define ESR_ELx_EC_HVC32 (0x12) /* EL2 only */ +#define ESR_ELx_EC_SMC32 (0x13) /* EL2 and above */ +/* Unallocated EC: 0x14 */ +#define ESR_ELx_EC_SVC64 (0x15) +#define ESR_ELx_EC_HVC64 (0x16) /* EL2 and above */ +#define ESR_ELx_EC_SMC64 (0x17) /* EL2 and above */ +#define ESR_ELx_EC_SYS64 (0x18) +#define ESR_ELx_EC_SVE (0x19) +/* Unallocated EC: 0x1A - 0x1E */ +#define ESR_ELx_EC_IMP_DEF (0x1f) /* EL3 only */ +#define ESR_ELx_EC_IABT_LOW (0x20) +#define ESR_ELx_EC_IABT_CUR (0x21) +#define ESR_ELx_EC_PC_ALIGN (0x22) +/* Unallocated EC: 0x23 */ +#define ESR_ELx_EC_DABT_LOW (0x24) +#define ESR_ELx_EC_DABT_CUR (0x25) +#define ESR_ELx_EC_SP_ALIGN (0x26) +/* Unallocated EC: 0x27 */ +#define ESR_ELx_EC_FP_EXC32 (0x28) +/* Unallocated EC: 0x29 - 0x2B */ +#define ESR_ELx_EC_FP_EXC64 (0x2C) +/* Unallocated EC: 0x2D - 0x2E */ +#define ESR_ELx_EC_SERROR (0x2F) +#define ESR_ELx_EC_BREAKPT_LOW (0x30) +#define ESR_ELx_EC_BREAKPT_CUR (0x31) +#define ESR_ELx_EC_SOFTSTP_LOW (0x32) +#define ESR_ELx_EC_SOFTSTP_CUR (0x33) +#define ESR_ELx_EC_WATCHPT_LOW (0x34) +#define ESR_ELx_EC_WATCHPT_CUR (0x35) +/* Unallocated EC: 0x36 - 0x37 */ +#define ESR_ELx_EC_BKPT32 (0x38) +/* Unallocated EC: 0x39 */ +#define ESR_ELx_EC_VECTOR32 (0x3A) /* EL2 only */ +/* Unallocated EC: 0x3B */ +#define ESR_ELx_EC_BRK64 (0x3C) +/* Unallocated EC: 0x3D - 0x3F */ +#define ESR_ELx_EC_MAX (0x3F) + +#define ESR_ELx_EC_SHIFT (26) +#define ESR_ELx_EC_MASK (UL(0x3F) << ESR_ELx_EC_SHIFT) +#define ESR_ELx_EC(esr) (((esr) & ESR_ELx_EC_MASK) >> ESR_ELx_EC_SHIFT) + +#define ESR_ELx_IL_SHIFT (25) +#define ESR_ELx_IL (UL(1) << ESR_ELx_IL_SHIFT) +#define ESR_ELx_ISS_MASK (ESR_ELx_IL - 1) + +/* ISS field definitions shared by different classes */ +#define ESR_ELx_WNR_SHIFT (6) +#define ESR_ELx_WNR (UL(1) << ESR_ELx_WNR_SHIFT) + +/* Asynchronous Error Type */ +#define ESR_ELx_IDS_SHIFT (24) +#define ESR_ELx_IDS (UL(1) << ESR_ELx_IDS_SHIFT) +#define ESR_ELx_AET_SHIFT (10) +#define ESR_ELx_AET (UL(0x7) << ESR_ELx_AET_SHIFT) + +#define ESR_ELx_AET_UC (UL(0) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UEU (UL(1) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UEO (UL(2) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_UER (UL(3) << ESR_ELx_AET_SHIFT) +#define ESR_ELx_AET_CE (UL(6) << ESR_ELx_AET_SHIFT) + +/* Shared ISS field definitions for Data/Instruction aborts */ +#define ESR_ELx_SET_SHIFT (11) +#define ESR_ELx_SET_MASK (UL(3) << ESR_ELx_SET_SHIFT) +#define ESR_ELx_FnV_SHIFT (10) +#define ESR_ELx_FnV (UL(1) << ESR_ELx_FnV_SHIFT) +#define ESR_ELx_EA_SHIFT (9) +#define ESR_ELx_EA (UL(1) << ESR_ELx_EA_SHIFT) +#define ESR_ELx_S1PTW_SHIFT (7) +#define ESR_ELx_S1PTW (UL(1) << ESR_ELx_S1PTW_SHIFT) + +/* Shared ISS fault status code (IFSC/DFSC) for Data/Instruction aborts */ +#define ESR_ELx_FSC (0x3F) +#define ESR_ELx_FSC_TYPE (0x3C) +#define ESR_ELx_FSC_EXTABT (0x10) +#define ESR_ELx_FSC_SERROR (0x11) +#define ESR_ELx_FSC_ACCESS (0x08) +#define ESR_ELx_FSC_FAULT (0x04) +#define ESR_ELx_FSC_PERM (0x0C) + +/* ISS field definitions for Data Aborts */ +#define ESR_ELx_ISV_SHIFT (24) +#define ESR_ELx_ISV (UL(1) << ESR_ELx_ISV_SHIFT) +#define ESR_ELx_SAS_SHIFT (22) +#define ESR_ELx_SAS (UL(3) << ESR_ELx_SAS_SHIFT) +#define ESR_ELx_SSE_SHIFT (21) +#define ESR_ELx_SSE (UL(1) << ESR_ELx_SSE_SHIFT) +#define ESR_ELx_SRT_SHIFT (16) +#define ESR_ELx_SRT_MASK (UL(0x1F) << ESR_ELx_SRT_SHIFT) +#define ESR_ELx_SF_SHIFT (15) +#define ESR_ELx_SF (UL(1) << ESR_ELx_SF_SHIFT) +#define ESR_ELx_AR_SHIFT (14) +#define ESR_ELx_AR (UL(1) << ESR_ELx_AR_SHIFT) +#define ESR_ELx_CM_SHIFT (8) +#define ESR_ELx_CM (UL(1) << ESR_ELx_CM_SHIFT) + +/* ISS field definitions for exceptions taken into Hyp */ +#define ESR_ELx_CV (UL(1) << 24) +#define ESR_ELx_COND_SHIFT (20) +#define ESR_ELx_COND_MASK (UL(0xF) << ESR_ELx_COND_SHIFT) +#define ESR_ELx_WFx_ISS_TI (UL(1) << 0) +#define ESR_ELx_WFx_ISS_WFI (UL(0) << 0) +#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) +#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) +
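These constants mirror Linux's esr.h; the exception class (EC) in bits 31:26 of ESR_EL1 is what the vector code below switches on. A standalone Go sketch of the same decoding (values copied from the defines above; the sample ESR values are hypothetical):

    package main

    import "fmt"

    const (
        esrECShift = 26
        esrECMask  = 0x3f << esrECShift

        ecSVC64   = 0x15 // ESR_ELx_EC_SVC64
        ecIABTLow = 0x20 // ESR_ELx_EC_IABT_LOW
        ecDABTLow = 0x24 // ESR_ELx_EC_DABT_LOW
    )

    func describe(esr uint64) string {
        switch (esr & esrECMask) >> esrECShift {
        case ecSVC64:
            return "svc: 64-bit system call"
        case ecIABTLow:
            return "instruction abort from EL0 (user code fault)"
        case ecDABTLow:
            return "data abort from EL0 (user page fault)"
        default:
            return "other exception class"
        }
    }

    func main() {
        fmt.Println(describe(0x56000000)) // svc from EL0
        fmt.Println(describe(0x92000045)) // hypothetical data abort
    }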
+#define LOAD_KERNEL_ADDRESS(from, to) \ + MOVD from, to; \ + ORR $0xffff000000000000, to, to; + +// LOAD_KERNEL_STACK loads the kernel temporary stack. +#define LOAD_KERNEL_STACK(from) \ + LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \ + MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \ + MOVD RSV_REG, RSP; \ + ISB $15; \ + DSB $15; + +#define SWITCH_TO_APP_PAGETABLE(from) \ + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + ISB $15; \ + DSB $15; + +#define SWITCH_TO_KVM_PAGETABLE(from) \ + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + ISB $15; \ + DSB $15; + +#define IRQ_ENABLE \ + MSR $2, DAIFClr; + +#define IRQ_DISABLE \ + MSR $2, DAIFSet; + +#define VFP_ENABLE \ + MOVD $FPEN_ENABLE, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define VFP_DISABLE \ + MOVD $0x0, R0; \ + WORD $0xd5181040; \ //MSR R0, CPACR_EL1 + ISB $15; + +#define KERNEL_ENTRY_FROM_EL0 \ + SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. + STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18; step2, switch to the kernel pagetable. + SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + MOVD RSV_REG_APP, R20; \ + LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ + ADD $16, RSP, RSP; \ + MOVD RSV_REG, PTRACE_R18(R20); \ + MOVD RSV_REG_APP, PTRACE_R9(R20); \ + MOVD R20, RSV_REG_APP; \ + WORD $0xd5384003; \ // MRS SPSR_EL1, R3 + MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MRS ELR_EL1, R3; \ + MOVD R3, PTRACE_PC(RSV_REG_APP); \ + WORD $0xd5384103; \ // MRS SP_EL0, R3 + MOVD R3, PTRACE_SP(RSV_REG_APP); + +#define KERNEL_ENTRY_FROM_EL1 \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context + MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \ + WORD $0xd5384004; \ // MRS SPSR_EL1, R4 + MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \ + MRS ELR_EL1, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \ + MOVD RSP, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); + +TEXT ·Halt(SB),NOSPLIT,$0 + // clear bluepill. + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + CMP RSV_REG, R9 + BNE mmio_exit + MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG) +mmio_exit: + // Disable fpsimd. + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, CPU_LAZY_VFP(RSV_REG) + VFP_DISABLE + + // MMIO_EXIT. + MOVD $0, R9 + MOVD R0, 0xffff000000001000(R9) + B ·kernelExitToEl1(SB) + +TEXT ·Shutdown(SB),NOSPLIT,$0 + // PSCI EVENT. + MOVD $0x84000009, R0 + HVC $0 + +// See kernel.go. 
+TEXT ·Current(SB),NOSPLIT,$0-8 + MOVD CPU_SELF(RSV_REG), R8 + MOVD R8, ret+0(FP) + RET + +#define STACK_FRAME_SIZE 16 + +TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 + ERET() + +TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 + ERET() + +TEXT ·Start(SB),NOSPLIT,$0 + IRQ_DISABLE + MOVD R8, RSV_REG + ORR $0xffff000000000000, RSV_REG, RSV_REG + WORD $0xd518d092 //MSR R18, TPIDR_EL1 + + B ·kernelExitToEl1(SB) + +TEXT ·El1_sync_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_irq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_error_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_sync(SB),NOSPLIT,$0 + KERNEL_ENTRY_FROM_EL1 + WORD $0xd5385219 // MRS ESR_EL1, R25 + LSR $ESR_ELx_EC_SHIFT, R25, R24 + CMP $ESR_ELx_EC_DABT_CUR, R24 + BEQ el1_da + CMP $ESR_ELx_EC_IABT_CUR, R24 + BEQ el1_ia + CMP $ESR_ELx_EC_SYS64, R24 + BEQ el1_undef + CMP $ESR_ELx_EC_SP_ALIGN, R24 + BEQ el1_sp_pc + CMP $ESR_ELx_EC_PC_ALIGN, R24 + BEQ el1_sp_pc + CMP $ESR_ELx_EC_UNKNOWN, R24 + BEQ el1_undef + CMP $ESR_ELx_EC_SVC64, R24 + BEQ el1_svc + CMP $ESR_ELx_EC_BREAKPT_CUR, R24 + BGE el1_dbg + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el1_fpsimd_acc + B el1_invalid + +el1_da: + B ·Halt(SB) + +el1_ia: + B ·Halt(SB) + +el1_sp_pc: + B ·Shutdown(SB) + +el1_undef: + B ·Shutdown(SB) + +el1_svc: + B ·Halt(SB) + +el1_dbg: + B ·Shutdown(SB) + +el1_fpsimd_acc: + VFP_ENABLE + B ·kernelExitToEl1(SB) // Resume. + +el1_invalid: + B ·Shutdown(SB) + +TEXT ·El1_irq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_fiq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El1_error(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_sync(SB),NOSPLIT,$0 + KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + LSR $ESR_ELx_EC_SHIFT, R25, R24 + CMP $ESR_ELx_EC_SVC64, R24 + BEQ el0_svc + CMP $ESR_ELx_EC_DABT_LOW, R24 + BEQ el0_da + CMP $ESR_ELx_EC_IABT_LOW, R24 + BEQ el0_ia + CMP $ESR_ELx_EC_FP_ASIMD, R24 + BEQ el0_fpsimd_acc + CMP $ESR_ELx_EC_SVE, R24 + BEQ el0_sve_acc + CMP $ESR_ELx_EC_FP_EXC64, R24 + BEQ el0_fpsimd_exc + CMP $ESR_ELx_EC_SP_ALIGN, R24 + BEQ el0_sp_pc + CMP $ESR_ELx_EC_PC_ALIGN, R24 + BEQ el0_sp_pc + CMP $ESR_ELx_EC_UNKNOWN, R24 + BEQ el0_undef + CMP $ESR_ELx_EC_BREAKPT_LOW, R24 + BGE el0_dbg + B el0_invalid + +el0_svc: + B ·Halt(SB) + +el0_da: + B ·Halt(SB) + +el0_ia: + B ·Shutdown(SB) + +el0_fpsimd_acc: + B ·Shutdown(SB) + +el0_sve_acc: + B ·Shutdown(SB) + +el0_fpsimd_exc: + B ·Shutdown(SB) + +el0_sp_pc: + B ·Shutdown(SB) + +el0_undef: + B ·Shutdown(SB) + +el0_dbg: + B ·Shutdown(SB) + +el0_invalid: + B ·Shutdown(SB) + +TEXT ·El0_irq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_fiq(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_error(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_irq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·El0_error_invalid(SB),NOSPLIT,$0 + B ·Shutdown(SB) + +TEXT ·Vectors(SB),NOSPLIT,$0 + B ·El1_sync_invalid(SB) + nop31Instructions() + B ·El1_irq_invalid(SB) + nop31Instructions() + B ·El1_fiq_invalid(SB) + nop31Instructions() + B ·El1_error_invalid(SB) + nop31Instructions() + + B ·El1_sync(SB) + nop31Instructions() + B ·El1_irq(SB) + nop31Instructions() + B ·El1_fiq(SB) + nop31Instructions() + B ·El1_error(SB) + nop31Instructions() + + B ·El0_sync(SB) + nop31Instructions() + B ·El0_irq(SB) + nop31Instructions() + B ·El0_fiq(SB) + nop31Instructions() + B ·El0_error(SB) + nop31Instructions() + + B ·El0_sync_invalid(SB) + nop31Instructions() + B 
·El0_irq_invalid(SB) + nop31Instructions() + B ·El0_fiq_invalid(SB) + nop31Instructions() + B ·El0_error_invalid(SB) + nop31Instructions() + + WORD $0xd503201f //nop + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() + WORD $0xd503201f + nop31Instructions() diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 780bf9a66..42076fb04 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -4,16 +4,24 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs", + template = "//pkg/sentry/platform/ring0:defs_arm64", +) + +go_template_instance( + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", + package = "main", + template = "//pkg/sentry/platform/ring0:defs_amd64", ) go_binary( name = "gen_offsets", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "main.go", ], visibility = ["//pkg/sentry/platform/ring0:__pkg__"], diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go new file mode 100644 index 000000000..ed82a131e --- /dev/null +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -0,0 +1,58 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// init initializes architecture-specific state. +func (k *Kernel) init(opts KernelOpts) { + // Save the root page tables. + k.PageTables = opts.PageTables +} + +// init initializes architecture-specific state. +func (c *CPU) init() { + // Set the kernel stack pointer (virtual address). + c.registers.Sp = uint64(c.StackTop()) + +} + +// StackTop returns the kernel's stack address. +// +//go:nosplit +func (c *CPU) StackTop() uint64 { + return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack)) +} + +// IsCanonical indicates whether addr is canonical per the arm64 spec. +// +//go:nosplit +func IsCanonical(addr uint64) bool { + return addr <= 0x0000ffffffffffff || addr >= 0xffff000000000000 +} +
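With the user/kernel split described earlier, a 64-bit address is canonical when bits 63:48 are all zeros (TTBR0 range) or all ones (TTBR1 range); the boundary value 0xffff000000000000 itself is the bottom of the kernel range and is therefore included. A tiny standalone mirror of IsCanonical for illustration:

    package main

    import "fmt"

    func isCanonical(addr uint64) bool {
        return addr <= 0x0000ffffffffffff || addr >= 0xffff000000000000
    }

    func main() {
        fmt.Println(isCanonical(0x0000fffffffff000)) // true: top of user range
        fmt.Println(isCanonical(0x0001000000000000)) // false: non-canonical hole
        fmt.Println(isCanonical(0xffff000000000000)) // true: bottom of kernel range
    }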
+//go:nosplit +func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { + // Sanitize registers. + regs := switchOpts.Registers + + regs.Pstate &= ^uint64(UserFlagsClear) + regs.Pstate |= UserFlagsSet + + // Perform the switch. + kernelExitToEl0() + vector = c.vecCode + return +} diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go new file mode 100644 index 000000000..8bcfe1032 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -0,0 +1,39 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +// CPACREL1 returns the value of the CPACR_EL1 register. +func CPACREL1() (value uintptr) + +// FPCR returns the value of the FPCR register. +func FPCR() (value uintptr) + +// SetFPCR writes the FPCR value. +func SetFPCR(value uintptr) + +// FPSR returns the value of the FPSR register. +func FPSR() (value uintptr) + +// SetFPSR writes the FPSR value. +func SetFPSR(value uintptr) + +// SaveVRegs saves V0-V31 registers. +// V0-V31: 32 128-bit registers for floating point and simd. +func SaveVRegs(*byte) + +// LoadVRegs loads V0-V31 registers. +func LoadVRegs(*byte) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s new file mode 100644 index 000000000..1c9171004 --- /dev/null +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -0,0 +1,118 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +TEXT ·CPACREL1(SB),NOSPLIT,$0-8 + WORD $0xd5381041 // MRS CPACR_EL1, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·FPCR(SB),NOSPLIT,$0-8 + WORD $0xd53b4401 // MRS FPCR, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·FPSR(SB),NOSPLIT,$0-8 + WORD $0xd53b4421 // MRS FPSR, R1 + MOVD R1, ret+0(FP) + RET + +TEXT ·SetFPCR(SB),NOSPLIT,$0-8 + MOVD value+0(FP), R1 + WORD $0xd51b4401 // MSR R1, FPCR + RET + +TEXT ·SetFPSR(SB),NOSPLIT,$0-8 + MOVD value+0(FP), R1 + WORD $0xd51b4421 // MSR R1, FPSR + RET + +TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. 
+ FMOVD F0, 16*1(R0) + FMOVD F1, 16*2(R0) + FMOVD F2, 16*3(R0) + FMOVD F3, 16*4(R0) + FMOVD F4, 16*5(R0) + FMOVD F5, 16*6(R0) + FMOVD F6, 16*7(R0) + FMOVD F7, 16*8(R0) + FMOVD F8, 16*9(R0) + FMOVD F9, 16*10(R0) + FMOVD F10, 16*11(R0) + FMOVD F11, 16*12(R0) + FMOVD F12, 16*13(R0) + FMOVD F13, 16*14(R0) + FMOVD F14, 16*15(R0) + FMOVD F15, 16*16(R0) + FMOVD F16, 16*17(R0) + FMOVD F17, 16*18(R0) + FMOVD F18, 16*19(R0) + FMOVD F19, 16*20(R0) + FMOVD F20, 16*21(R0) + FMOVD F21, 16*22(R0) + FMOVD F22, 16*23(R0) + FMOVD F23, 16*24(R0) + FMOVD F24, 16*25(R0) + FMOVD F25, 16*26(R0) + FMOVD F26, 16*27(R0) + FMOVD F27, 16*28(R0) + FMOVD F28, 16*29(R0) + FMOVD F29, 16*30(R0) + FMOVD F30, 16*31(R0) + FMOVD F31, 16*32(R0) + ISB $15 + + RET + +TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R0 + + // Skip aarch64_ctx, fpsr, fpcr. + FMOVD 16*1(R0), F0 + FMOVD 16*2(R0), F1 + FMOVD 16*3(R0), F2 + FMOVD 16*4(R0), F3 + FMOVD 16*5(R0), F4 + FMOVD 16*6(R0), F5 + FMOVD 16*7(R0), F6 + FMOVD 16*8(R0), F7 + FMOVD 16*9(R0), F8 + FMOVD 16*10(R0), F9 + FMOVD 16*11(R0), F10 + FMOVD 16*12(R0), F11 + FMOVD 16*13(R0), F12 + FMOVD 16*14(R0), F13 + FMOVD 16*15(R0), F14 + FMOVD 16*16(R0), F15 + FMOVD 16*17(R0), F16 + FMOVD 16*18(R0), F17 + FMOVD 16*19(R0), F18 + FMOVD 16*20(R0), F19 + FMOVD 16*21(R0), F20 + FMOVD 16*22(R0), F21 + FMOVD 16*23(R0), F22 + FMOVD 16*24(R0), F23 + FMOVD 16*25(R0), F24 + FMOVD 16*26(R0), F25 + FMOVD 16*27(R0), F26 + FMOVD 16*28(R0), F27 + FMOVD 16*29(R0), F28 + FMOVD 16*30(R0), F29 + FMOVD 16*31(R0), F30 + FMOVD 16*32(R0), F31 + ISB $15 + + RET diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go new file mode 100644 index 000000000..cd2a65f97 --- /dev/null +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -0,0 +1,125 @@ +// Copyright 2019 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package ring0 + +import ( + "fmt" + "io" + "reflect" + "syscall" +) + +// Emit prints architecture-specific offsets. 
+func Emit(w io.Writer) { + fmt.Fprintf(w, "// Automatically generated, do not edit.\n") + + c := &CPU{} + fmt.Fprintf(w, "\n// CPU offsets.\n") + fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) + fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_FAULT_ADDR 0x%02x\n", reflect.ValueOf(&c.faultAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_TTBR0_KVM 0x%02x\n", reflect.ValueOf(&c.ttbr0Kvm).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + + fmt.Fprintf(w, "\n// Bits.\n") + fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) + + fmt.Fprintf(w, "\n// Vectors.\n") + fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) + fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) + fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) + fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) + + fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) + fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) + fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) + fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + + fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) + fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) + fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) + fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + + fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) + fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) + fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) + fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + + fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) + fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) + fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) + fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) + fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) + fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + + fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) + fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) + fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) + fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) + fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) + fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) + fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) + fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) + fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", 
El0Sync_dbg) + fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + + fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) + fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) + + p := &syscall.PtraceRegs{} + fmt.Fprintf(w, "\n// Ptrace registers.\n") + fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R2 0x%02x\n", reflect.ValueOf(&p.Regs[2]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R3 0x%02x\n", reflect.ValueOf(&p.Regs[3]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R4 0x%02x\n", reflect.ValueOf(&p.Regs[4]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R5 0x%02x\n", reflect.ValueOf(&p.Regs[5]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R6 0x%02x\n", reflect.ValueOf(&p.Regs[6]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R7 0x%02x\n", reflect.ValueOf(&p.Regs[7]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.Regs[8]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.Regs[9]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.Regs[10]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.Regs[11]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.Regs[12]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.Regs[13]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.Regs[14]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.Regs[15]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R16 0x%02x\n", reflect.ValueOf(&p.Regs[16]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R17 0x%02x\n", reflect.ValueOf(&p.Regs[17]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R18 0x%02x\n", reflect.ValueOf(&p.Regs[18]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R19 0x%02x\n", reflect.ValueOf(&p.Regs[19]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R20 0x%02x\n", reflect.ValueOf(&p.Regs[20]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R21 0x%02x\n", reflect.ValueOf(&p.Regs[21]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R22 0x%02x\n", reflect.ValueOf(&p.Regs[22]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R23 0x%02x\n", reflect.ValueOf(&p.Regs[23]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R24 0x%02x\n", reflect.ValueOf(&p.Regs[24]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R25 0x%02x\n", reflect.ValueOf(&p.Regs[25]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R26 0x%02x\n", reflect.ValueOf(&p.Regs[26]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R27 0x%02x\n", reflect.ValueOf(&p.Regs[27]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define 
PTRACE_R28 0x%02x\n", reflect.ValueOf(&p.Regs[28]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R29 0x%02x\n", reflect.ValueOf(&p.Regs[29]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_R30 0x%02x\n", reflect.ValueOf(&p.Regs[30]).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer()) +} diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 934a90378..e2e15ba5c 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,14 +1,17 @@ -load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +config_setting( + name = "aarch64", + constraint_values = ["@bazel_tools//platforms:aarch64"], +) + go_template( name = "generic_walker", - srcs = [ - "walker_amd64.go", - ], + srcs = ["walker_amd64.go"], opt_types = [ "Visitor", ], @@ -76,9 +79,13 @@ go_library( "allocator.go", "allocator_unsafe.go", "pagetables.go", + "pagetables_aarch64.go", "pagetables_amd64.go", + "pagetables_arm64.go", "pagetables_x86.go", "pcids_x86.go", + "walker_amd64.go", + "walker_arm64.go", "walker_empty.go", "walker_lookup.go", "walker_map.go", @@ -97,6 +104,7 @@ go_test( size = "small", srcs = [ "pagetables_amd64_test.go", + "pagetables_arm64_test.go", "pagetables_test.go", "walker_check.go", ], diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 904f1a6de..30c64a372 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -48,15 +48,6 @@ func New(a Allocator) *PageTables { return p } -// Init initializes a set of PageTables. -// -//go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) -} - // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go new file mode 100644 index 000000000..e78424766 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -0,0 +1,212 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +// archPageTables is architecture-specific data. 
+type archPageTables struct { + // root is the pagetable root for kernel space. + root *PTEs + + // rootPhysical is the cached physical address of the root. + // + // This is saved only to prevent constant translation. + rootPhysical uintptr + + asid uint16 +} + +// TTBR0_EL1 returns the translation table base register 0. +// +//go:nosplit +func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { + return uint64(p.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset +} + +// TTBR1_EL1 returns the translation table base register 1. +// +//go:nosplit +func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { + return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset +} + +// Bits in page table entries. +const ( + typeTable = 0x3 << 0 + typeSect = 0x1 << 0 + typePage = 0x3 << 0 + pteValid = 0x1 << 0 + pteTableBit = 0x1 << 1 + pteTypeMask = 0x3 << 0 + present = pteValid | pteTableBit + user = 0x1 << 6 /* AP[1] */ + readOnly = 0x1 << 7 /* AP[2] */ + accessed = 0x1 << 10 + dbm = 0x1 << 51 + writable = dbm + cont = 0x1 << 52 + pxn = 0x1 << 53 + xn = 0x1 << 54 + dirty = 0x1 << 55 + nG = 0x1 << 11 + shared = 0x3 << 8 +) + +const ( + mtNormal = 0x4 << 2 +) + +const ( + executeDisable = xn + optionMask = 0xfff | 0xfff<<48 + protDefault = accessed | shared | mtNormal +) + +// MapOpts are arm64 options. +type MapOpts struct { + // AccessType defines permissions. + AccessType usermem.AccessType + + // Global indicates the page is globally accessible. + Global bool + + // User indicates the page is a user page. + User bool +} + +// PTE is a page table entry. +type PTE uintptr + +// Clear clears this PTE, including sect page information. +// +//go:nosplit +func (p *PTE) Clear() { + atomic.StoreUintptr((*uintptr)(p), 0) +} + +// Valid returns true iff this entry is valid. +// +//go:nosplit +func (p *PTE) Valid() bool { + return atomic.LoadUintptr((*uintptr)(p))&present != 0 +} + +// Opts returns the PTE options. +// +// These are all options except Valid and Sect. +// +//go:nosplit +func (p *PTE) Opts() MapOpts { + v := atomic.LoadUintptr((*uintptr)(p)) + + return MapOpts{ + AccessType: usermem.AccessType{ + Read: true, + Write: v&readOnly == 0, + Execute: v&xn == 0, + }, + Global: v&nG == 0, + User: v&user != 0, + } +} + +// SetSect sets this page as a sect page. +// +// The page must not be valid or a panic will result. +// +//go:nosplit +func (p *PTE) SetSect() { + if p.Valid() { + // This is not allowed. + panic("SetSect called on valid page!") + } + atomic.StoreUintptr((*uintptr)(p), typeSect) +} + +// IsSect returns true iff this page is a sect page. +// +//go:nosplit +func (p *PTE) IsSect() bool { + return atomic.LoadUintptr((*uintptr)(p))&pteTypeMask == typeSect +} +
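The Set method just below turns a MapOpts into these bits. As a standalone reference (simplified: sect pages and the Global bit are ignored, and the numeric values are copied from the const block above), a user mapping is assembled roughly like this:

    package main

    import "fmt"

    const (
        typePage    = 0x3 << 0
        user        = 0x1 << 6 // AP[1]
        readOnly    = 0x1 << 7 // AP[2]
        shared      = 0x3 << 8
        accessed    = 0x1 << 10
        nG          = 0x1 << 11
        xn          = uint64(0x1) << 54
        mtNormal    = 0x4 << 2
        protDefault = accessed | shared | mtNormal
    )

    // makeUserPTE mimics PTE.Set for a normal 4K user page: start read-only
    // and non-global, then relax bits according to the access type.
    func makeUserPTE(phys uint64, write, exec bool) uint64 {
        v := phys | protDefault | nG | readOnly | typePage | user
        if write {
            v &^= readOnly // writable is the absence of AP[2]
        }
        if !exec {
            v |= xn // execute-never
        }
        return v
    }

    func main() {
        fmt.Printf("%#x\n", makeUserPTE(0x40000000, true, false))
    }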
+// Set sets this PTE value. +// +// This does not change the sect page property. +// +//go:nosplit +func (p *PTE) Set(addr uintptr, opts MapOpts) { + if !opts.AccessType.Any() { + p.Clear() + return + } + v := (addr &^ optionMask) | protDefault | nG | readOnly + + if p.IsSect() { + // Note that this is inherited from the previous instance. Set + // does not change the value of Sect. See above. + v |= typeSect + } else { + v |= typePage + } + + if opts.Global { + v = v &^ nG + } + + if opts.AccessType.Execute { + v = v &^ executeDisable + } else { + v |= executeDisable + } + if opts.AccessType.Write { + v = v &^ readOnly + } + + if opts.User { + v |= user + } else { + v = v &^ user + } + atomic.StoreUintptr((*uintptr)(p), v) +} + +// setPageTable sets this PTE value and forces the write bit and sect bit to +// be cleared. This is used explicitly for breaking sect pages. +// +//go:nosplit +func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) { + addr := pt.Allocator.PhysicalFor(ptes) + if addr&^optionMask != addr { + // This should never happen. + panic("unaligned physical address!") + } + v := addr | typeTable | protDefault + atomic.StoreUintptr((*uintptr)(p), v) +} + +// Address extracts the address. This should only be used if Valid returns true. +// +//go:nosplit +func (p *PTE) Address() uintptr { + return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 7aa6c524e..0c153cf8c 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -41,5 +41,14 @@ const ( entriesPerPage = 512 ) +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) +} + // PTEs is a collection of entries. type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go new file mode 100644 index 000000000..1a49f12a2 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -0,0 +1,57 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +// Address constraints. +// +// The lowerTop and upperBottom currently apply to four-level pagetables; +// additional refactoring would be necessary to support five-level pagetables. +const ( + lowerTop = 0x0000ffffffffffff + upperBottom = 0xffff000000000000 + pteShift = 12 + pmdShift = 21 + pudShift = 30 + pgdShift = 39 + + pteMask = 0x1ff << pteShift + pmdMask = 0x1ff << pmdShift + pudMask = 0x1ff << pudShift + pgdMask = 0x1ff << pgdShift + + pteSize = 1 << pteShift + pmdSize = 1 << pmdShift + pudSize = 1 << pudShift + pgdSize = 1 << pgdShift + + ttbrASIDOffset = 55 + ttbrASIDMask = 0xff + + entriesPerPage = 512 +) + +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) + p.archPageTables.root = p.Allocator.NewPTEs() + p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +} + +// PTEs is a collection of entries. 
+type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go new file mode 100644 index 000000000..254116233 --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go @@ -0,0 +1,80 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func Test2MAnd4K(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a small page and a huge page. + pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47) + + pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42) + pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, + {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}}, + {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}}, + }) +} + +func Test1GAnd4K(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a small page and a super page. + pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47) + + checkMappings(t, pt, []mapping{ + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, + {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} + +func TestSplit1GPage(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a super page and knock out the middle. + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42) + pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} + +func TestSplit2MPage(t *testing.T) { + pt := New(NewRuntimeAllocator()) + + // Map a huge page and knock out the middle. 
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42) + pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) + + checkMappings(t, pt, []mapping{ + {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + }) +} diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go new file mode 100644 index 000000000..c261d393a --- /dev/null +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package pagetables + +// Visitor is a generic type. +type Visitor interface { + // visit is called on each PTE. + visit(start uintptr, pte *PTE, align uintptr) + + // requiresAlloc indicates that new entries should be allocated within + // the walked range. + requiresAlloc() bool + + // requiresSplit indicates that entries in the given range should be + // split if they are huge or jumbo pages. + requiresSplit() bool +} + +// Walker walks page tables. +type Walker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor Visitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is sect pages. If a valid sect page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of sect pages whenever +// possible. Whether a sect page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// +// Precondition: start must be less than end. +// +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. 
+// +//go:nosplit +func (w *Walker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func next(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} + +// iterateRangeCanonical walks a canonical range. +// +//go:nosplit +func (w *Walker) iterateRangeCanonical(start, end uintptr) { + pgdEntryIndex := w.pageTables.root + if start >= upperBottom { + pgdEntryIndex = w.pageTables.archPageTables.root + } + + for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { + var ( + pgdEntry = &pgdEntryIndex[pgdIndex] + pudEntries *PTEs + ) + if !pgdEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + start = next(start, pgdSize) + continue + } + + // Allocate a new pgd. + pudEntries = w.pageTables.Allocator.NewPTEs() + pgdEntry.setPageTable(w.pageTables, pudEntries) + } else { + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + } + + // Map the next level. + clearPUDEntries := uint16(0) + + for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ { + var ( + pudEntry = &pudEntries[pudIndex] + pmdEntries *PTEs + ) + if !pudEntry.Valid() { + if !w.visitor.requiresAlloc() { + // Skip over this entry. + clearPUDEntries++ + start = next(start, pudSize) + continue + } + + // This level has 1-GB sect pages. Is this + // entire region at least as large as a single + // PUD entry? If so, we can skip allocating a + // new page for the pmd. + if start&(pudSize-1) == 0 && end-start >= pudSize { + pudEntry.SetSect() + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if pudEntry.Valid() { + start = next(start, pudSize) + continue + } + } + + // Allocate a new pud. + pmdEntries = w.pageTables.Allocator.NewPTEs() + pudEntry.setPageTable(w.pageTables, pmdEntries) + + } else if pudEntry.IsSect() { + // Does this page need to be split? + if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) { + // Install the relevant entries. + pmdEntries = w.pageTables.Allocator.NewPTEs() + for index := uint16(0); index < entriesPerPage; index++ { + pmdEntries[index].SetSect() + pmdEntries[index].Set( + pudEntry.Address()+(pmdSize*uintptr(index)), + pudEntry.Opts()) + } + pudEntry.setPageTable(w.pageTables, pmdEntries) + } else { + // A sect page to be checked directly. + w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + + // Might have been cleared. + if !pudEntry.Valid() { + clearPUDEntries++ + } + + // Note that the sect page was changed. 
+ start = next(start, pudSize)
+ continue
+ }
+
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPMDEntries++
+ start = next(start, pmdSize)
+ continue
+ }
+
+ // This level has 2-MB huge pages. Is this
+ // region contained in a single PMD entry?
+ // As above, we can skip allocating a new page.
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSect()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = next(start, pmdSize)
+ continue
+ }
+ }
+
+ // Allocate a new pmd.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) {
+ // Install the relevant entries.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+ // A huge page to be checked directly.
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ // Might have been cleared.
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ // Note that the huge page was changed.
+ start = next(start, pmdSize)
+ continue
+ }
+
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ // At this point, we are guaranteed that start%pteSize == 0.
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ // Note that the pte was changed.
+ start += pteSize
+ continue
+ }
+
+ // Check if we no longer need this page.
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
index 2f65db70b..ba1f9043d 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -16,7 +16,6 @@ package sighandling
 import (
- "fmt"
 "os"
 "os/signal"
 "reflect"
@@ -31,37 +30,25 @@ const numSignals = 32
 // handleSignals listens for incoming signals and calls the given handler
 // function.
// -// It starts when the start channel is closed, stops when the stop channel -// is closed, and closes done once it will no longer deliver signals to k. -func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) { +// It stops when the stop channel is closed. The done channel is closed once it +// will no longer deliver signals to k. +func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) { // Build a select case. - sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}} + sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}} for _, sigchan := range sigchans { sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)}) } - started := false for { // Wait for a notification. index, _, ok := reflect.Select(sc) - // Was it the start / stop channel? + // Was it the stop channel? if index == 0 { if !ok { - if !started { - // start channel; start forwarding and - // swap this case for the stop channel - // to select stop requests. - started = true - sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)} - } else { - // stop channel; stop forwarding and - // clear this case so it is never - // selected again. - started = false - close(done) - sc[0].Chan = reflect.Value{} - } + // Stop forwarding and notify that it's done. + close(done) + return } continue } @@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, // Otherwise, it was a signal on channel N. Index 0 represents the stop // channel, so index N represents the channel for signal N. - signal := linux.Signal(index) - - if !started { - // Kernel cannot receive signals, either because it is - // not ready yet or is shutting down. - // - // Kill ourselves if this signal would have killed the - // process before PrepareForwarding was called. i.e., all - // _SigKill signals; see Go - // src/runtime/sigtab_linux_generic.go. - // - // Otherwise ignore the signal. - // - // TODO(b/114489875): Drop in Go 1.12, which uses tgkill - // in runtime.raise. - switch signal { - case linux.SIGHUP, linux.SIGINT, linux.SIGTERM: - dieFromSignal(signal) - panic(fmt.Sprintf("Failed to die from signal %d", signal)) - default: - continue - } - } - - // Pass the signal to the handler. - handler(signal) + handler(linux.Signal(index)) } } -// PrepareHandler ensures that synchronous signals are passed to the given -// handler function and returns a callback that starts signal delivery, which -// itself returns a callback that stops signal handling. +// StartSignalForwarding ensures that synchronous signals are passed to the +// given handler function and returns a callback that stops signal delivery. // // Note that this function permanently takes over signal handling. After the // stop callback, signals revert to the default Go runtime behavior, which // cannot be overridden with external calls to signal.Notify. -func PrepareHandler(handler func(linux.Signal)) func() func() { - start := make(chan struct{}) +func StartSignalForwarding(handler func(linux.Signal)) func() { stop := make(chan struct{}) done := make(chan struct{}) @@ -128,13 +88,10 @@ func PrepareHandler(handler func(linux.Signal)) func() func() { signal.Notify(sigchan, syscall.Signal(sig)) } // Start up our listener. - go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. 
+ go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu. - return func() func() { - close(start) - return func() { - close(stop) - <-done - } + return func() { + close(stop) + <-done } } diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index c303435d5..1ebe22d34 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -15,8 +15,6 @@ package sighandling import ( - "fmt" - "runtime" "syscall" "unsafe" @@ -48,27 +46,3 @@ func IgnoreChildStop() error { return nil } - -// dieFromSignal kills the current process with sig. -// -// Preconditions: The default action of sig is termination. -func dieFromSignal(sig linux.Signal) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - sa := sigaction{handler: linux.SIG_DFL} - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigaction failed: %v", e)) - } - - set := linux.MakeSignalSet(sig) - if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 { - panic(fmt.Sprintf("rt_sigprocmask failed: %v", e)) - } - - if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil { - panic(fmt.Sprintf("tgkill failed: %v", err)) - } - - panic("failed to die") -} diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 4a6e83a8b..357517ed4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -17,6 +17,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserror", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4e95101b7..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" @@ -194,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ // the available space, we must align down. // // align must be >= 4 and each data int32 is 4 bytes. The length of the - // header is already aligned, so if we align to the with of the data there + // header is already aligned, so if we align to the width of the data there // are two cases: // 1. The aligned length is less than the length of the header. The // unaligned length was also less than the length of the header, so we // can't write anything. // 2. The aligned length is greater than or equal to the length of the - // header. We can write the header plus zero or more datas. We can't write - // a partial int32, so the length of the message will be - // min(aligned length, header + datas). + // header. We can write the header plus zero or more bytes of data. We can't + // write a partial int32, so the length of the message will be + // min(aligned length, header + data). 
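// As a concrete reading of the two cases (assuming the 16-byte control
// message header of 64-bit Linux and an align of 8): with space = 30 the
// aligned length is 24, leaving room for the header plus exactly two int32s;
// with space = 12 the aligned length is 8, below the header size, so nothing
// can be written.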
if space < linux.SizeOfControlMessageHeader { flags |= linux.MSG_CTRUNC return buf, flags @@ -239,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = binary.Marshal(buf, usermem.ByteOrder, data) - // Check if we went over. + // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { return hdrBuf } - // Fix up length. + // Update control message length to include data. putUint64(ob, uint64(len(buf)-len(ob))) return alignSlice(buf, align) @@ -320,35 +321,109 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { buf, linux.SOL_TCP, linux.TCP_INQ, - 4, + t.Arch().Width(), inq, ) } +// PackTOS packs an IP_TOS socket control message. +func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_TOS, + t.Arch().Width(), + tos, + ) +} + +// PackTClass packs an IPV6_TCLASS socket control message. +func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_TCLASS, + t.Arch().Width(), + tClass, + ) +} + +// PackControlMessages packs control messages into the given buffer. +// +// We skip control messages specific to Unix domain sockets. +// +// Note that some control messages may be truncated if they do not fit under +// the capacity of buf. +func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { + if cmsgs.IP.HasTimestamp { + buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) + } + + if cmsgs.IP.HasInq { + // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. + buf = PackInq(t, cmsgs.IP.Inq, buf) + } + + if cmsgs.IP.HasTOS { + buf = PackTOS(t, cmsgs.IP.TOS, buf) + } + + if cmsgs.IP.HasTClass { + buf = PackTClass(t, cmsgs.IP.TClass, buf) + } + + return buf +} + +// cmsgSpace is equivalent to CMSG_SPACE in Linux. +func cmsgSpace(t *kernel.Task, dataLen int) int { + return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width()) +} + +// CmsgsSpace returns the number of bytes needed to fit the control messages +// represented in cmsgs. +func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { + space := 0 + + if cmsgs.IP.HasTimestamp { + space += cmsgSpace(t, linux.SizeOfTimeval) + } + + if cmsgs.IP.HasInq { + space += cmsgSpace(t, linux.SizeOfControlMessageInq) + } + + if cmsgs.IP.HasTOS { + space += cmsgSpace(t, linux.SizeOfControlMessageTOS) + } + + if cmsgs.IP.HasTClass { + space += cmsgSpace(t, linux.SizeOfControlMessageTClass) + } + + return space +} + // Parse parses a raw socket control message into portable objects. 
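// As a sketch of the accepted layout (sizes and values illustrative), a
// buffer carrying a single SCM_RIGHTS message looks like
//
//	| cmsg_len | cmsg_level = SOL_SOCKET | cmsg_type = SCM_RIGHTS | fd, fd, ... |
//
// and the parsed descriptors come back in the returned ControlMessages'
// Unix.Rights field.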
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { +func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( - fds linux.ControlMessageRights - haveCreds bool - creds linux.ControlMessageCredentials + cmsgs socket.ControlMessages + fds linux.ControlMessageRights ) for i := 0; i < len(buf); { if i+linux.SizeOfControlMessageHeader > len(buf) { - return transport.ControlMessages{}, syserror.EINVAL + return cmsgs, syserror.EINVAL } var h linux.ControlMessageHeader binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) if h.Length < uint64(linux.SizeOfControlMessageHeader) { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } if h.Length > uint64(len(buf)-i) { - return transport.ControlMessages{}, syserror.EINVAL - } - if h.Level != linux.SOL_SOCKET { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } i += linux.SizeOfControlMessageHeader @@ -358,59 +433,79 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport. // sizeof(long) in CMSG_ALIGN. width := t.Arch().Width() - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) - numRights := rightsSize / linux.SizeOfControlMessageRight - - if len(fds)+numRights > linux.SCM_MAX_FD { - return transport.ControlMessages{}, syserror.EINVAL + switch h.Level { + case linux.SOL_SOCKET: + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) + numRights := rightsSize / linux.SizeOfControlMessageRight + + if len(fds)+numRights > linux.SCM_MAX_FD { + return socket.ControlMessages{}, syserror.EINVAL + } + + for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { + fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + } + + i += AlignUp(length, width) + + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + return socket.ControlMessages{}, syserror.EINVAL + } + + var creds linux.ControlMessageCredentials + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) + scmCreds, err := NewSCMCredentials(t, creds) + if err != nil { + return socket.ControlMessages{}, err + } + cmsgs.Unix.Credentials = scmCreds + i += AlignUp(length, width) + + default: + // Unknown message type. 
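// For instance, an SCM_TIMESTAMP message, which this parser does
// not handle, would land here and fail with EINVAL.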
+ return socket.ControlMessages{}, syserror.EINVAL } - - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + case linux.SOL_IP: + switch h.Type { + case linux.IP_TOS: + cmsgs.IP.HasTOS = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) + i += AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - i += AlignUp(length, width) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { - return transport.ControlMessages{}, syserror.EINVAL + case linux.SOL_IPV6: + switch h.Type { + case linux.IPV6_TCLASS: + cmsgs.IP.HasTClass = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) + i += AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) - haveCreds = true - i += AlignUp(length, width) - default: - // Unknown message type. - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } } - var credentials SCMCredentials - if haveCreds { - var err error - if credentials, err = NewSCMCredentials(t, creds); err != nil { - return transport.ControlMessages{}, err - } - } else { - credentials = makeCreds(t, socketOrEndpoint) + if cmsgs.Unix.Credentials == nil { + cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint) } - var rights SCMRights if len(fds) > 0 { - var err error - if rights, err = NewSCMRights(t, fds); err != nil { - return transport.ControlMessages{}, err + rights, err := NewSCMRights(t, fds) + if err != nil { + return socket.ControlMessages{}, err } + cmsgs.Unix.Rights = rights } - if credentials == nil && rights == nil { - return transport.ControlMessages{}, nil - } - - return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil + return cmsgs, nil } func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4d174dda4..4c44c7c0f 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -29,9 +29,12 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", + "//pkg/sentry/socket/control", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip/stack", "//pkg/waiter", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 92beb1bcf..c957b0f1d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -18,6 +18,7 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/fdnotifier" @@ -29,6 +30,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -41,6 +43,10 @@ const ( // sizeofSockaddr is the size in bytes of the largest sockaddr type // supported by this package. 
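// (Concretely, sizeof(sockaddr_in6) is 28 bytes on Linux versus 16 for
// sizeof(sockaddr_in), so one buffer of this size holds either family.)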
sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) + + // maxControlLen is the maximum size of a control message buffer used in a + // recvmsg or sendmsg syscall. + maxControlLen = 1024 ) // socketOperations implements fs.FileOperations and socket.Socket for a socket @@ -281,26 +287,32 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt // Whitelist options and constrain option length. var optlen int switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 - case syscall.SO_LINGER: + case linux.SO_LINGER: optlen = syscall.SizeofLinger } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 - case syscall.TCP_INFO: + case linux.TCP_INFO: optlen = int(linux.SizeOfTCPInfo) } } + if optlen == 0 { return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT } @@ -320,19 +332,24 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ // Whitelist options and constrain option length. var optlen int switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 } } @@ -354,11 +371,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Whitelist flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that netstack/tcpip/transport/unix doesn't understand. Kill the + // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the // Socket interface's dependence on netstack. 
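// So, for example, recvmsg with MSG_PEEK|MSG_DONTWAIT passes the check
// below unchanged, while an unsupported flag such as MSG_ERRQUEUE trips
// it and returns EINVAL.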
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument @@ -370,6 +387,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags senderAddrBuf = make([]byte, sizeofSockaddr) } + var controlBuf []byte var msgFlags int recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -384,11 +402,6 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We always do a non-blocking recv*(). sysflags := flags | syscall.MSG_DONTWAIT - if dsts.NumBlocks() == 1 { - // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) - } - iovs := iovecsFromBlockSeq(dsts) msg := syscall.Msghdr{ Iov: &iovs[0], @@ -398,12 +411,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Name = &senderAddrBuf[0] msg.Namelen = uint32(len(senderAddrBuf)) } + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) + controlLen = uint64(msg.Controllen) return n, nil }) @@ -429,14 +451,38 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err = dst.CopyOutFrom(t, recvmsgToBlocks) } } - - // We don't allow control messages. - msgFlags &^= linux.MSG_CTRUNC + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) + + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } + + controlMessages := socket.ControlMessages{} + for _, unixCmsg := range unixControlMessages { + switch unixCmsg.Header.Level { + case syscall.SOL_IP: + switch unixCmsg.Header.Type { + case syscall.IP_TOS: + controlMessages.IP.HasTOS = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + } + case syscall.SOL_IPV6: + switch unixCmsg.Header.Type { + case syscall.IPV6_TCLASS: + controlMessages.IP.HasTClass = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + } + } + } + + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil } // SendMsg implements socket.Socket.SendMsg. @@ -446,6 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.ErrInvalidArgument } + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. 
 if uint64(src.NumBytes()) != srcs.NumBytes() {
@@ -458,7 +512,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
 // We always do a non-blocking send*().
 sysflags := flags | syscall.MSG_DONTWAIT
 
- if srcs.NumBlocks() == 1 {
+ if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
 // Skip allocating []syscall.Iovec.
 src := srcs.Head()
 n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
@@ -477,6 +531,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
 msg.Name = &to[0]
 msg.Namelen = uint32(len(to))
 }
+ if len(controlBuf) != 0 {
+ msg.Control = &controlBuf[0]
+ msg.Controllen = uint64(len(controlBuf))
+ }
 return sendmsg(s.fd, &msg, sysflags)
 })
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 3a4fdec47..e67b46c9e 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -16,8 +16,11 @@ package hostinet
 import (
 "fmt"
+ "io"
 "io/ioutil"
 "os"
+ "reflect"
+ "strconv"
 "strings"
 "syscall"
@@ -26,7 +29,9 @@ import (
 "gvisor.dev/gvisor/pkg/sentry/context"
 "gvisor.dev/gvisor/pkg/sentry/inet"
 "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserr"
 "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
 var defaultRecvBufSize = inet.TCPBufferSize{
@@ -51,6 +56,8 @@ type Stack struct {
 tcpRecvBufSize inet.TCPBufferSize
 tcpSendBufSize inet.TCPBufferSize
 tcpSACKEnabled bool
+ netDevFile *os.File
+ netSNMPFile *os.File
 }
 
 // NewStack returns an empty Stack containing no configuration.
@@ -98,6 +105,18 @@ func (s *Stack) Configure() error {
 log.Warningf("Failed to read if TCP SACK if enabled, setting to true")
 }
 
+ if f, err := os.Open("/proc/net/dev"); err != nil {
+ log.Warningf("Failed to open /proc/net/dev: %v", err)
+ } else {
+ s.netDevFile = f
+ }
+
+ if f, err := os.Open("/proc/net/snmp"); err != nil {
+ log.Warningf("Failed to open /proc/net/snmp: %v", err)
+ } else {
+ s.netSNMPFile = f
+ }
+
 return nil
 }
 
@@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
 return syserror.EACCES
 }
 
+// getLine reads the given proc file from the beginning and returns the first
+// line that begins with the given prefix. If withHeader is true, the first
+// matching line is treated as a column header and the next matching line is
+// returned instead.
+func getLine(f *os.File, prefix string, withHeader bool) string {
+ data := make([]byte, 4096)
+
+ if _, err := f.Seek(0, 0); err != nil {
+ return ""
+ }
+
+ if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF {
+ return ""
+ }
+
+ prefix = prefix + ":"
+ lines := strings.Split(string(data), "\n")
+ for _, l := range lines {
+ l = strings.TrimSpace(l)
+ if strings.HasPrefix(l, prefix) {
+ if withHeader {
+ withHeader = false
+ continue
+ }
+ return l
+ }
+ }
+ return ""
+}
+
+func toSlice(i interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(i))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
 // Statistics implements inet.Stack.Statistics.
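//
// A minimal usage sketch (the "Tcp" argument assumes the standard prefix of
// the TCP lines in /proc/net/snmp):
//
//	var snmp inet.StatSNMPTCP
//	err := s.Statistics(&snmp, "Tcp") // fills snmp from the host's /proc/net/snmp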
func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserror.EOPNOTSUPP + var ( + snmpTCP bool + rawLine string + sliceStat []uint64 + ) + + switch stat.(type) { + case *inet.StatDev: + if s.netDevFile == nil { + return fmt.Errorf("/proc/net/dev is not opened for hostinet") + } + rawLine = getLine(s.netDevFile, arg, false /* with no header */) + case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite: + if s.netSNMPFile == nil { + return fmt.Errorf("/proc/net/snmp is not opened for hostinet") + } + rawLine = getLine(s.netSNMPFile, arg, true) + default: + return syserr.ErrEndpointOperation.ToError() + } + + if rawLine == "" { + return fmt.Errorf("Failed to get raw line") + } + + parts := strings.SplitN(rawLine, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("Failed to get prefix from: %q", rawLine) + } + + sliceStat = toSlice(stat) + fields := strings.Fields(strings.TrimSpace(parts[1])) + if len(fields) != len(sliceStat) { + return fmt.Errorf("Failed to parse fields: %q", rawLine) + } + if _, ok := stat.(*inet.StatSNMPTCP); ok { + snmpTCP = true + } + for i := 0; i < len(sliceStat); i++ { + var err error + if snmpTCP && i == 3 { + var tmp int64 + // MaxConn field is signed, RFC 2012. + tmp, err = strconv.ParseInt(fields[i], 10, 64) + sliceStat[i] = uint64(tmp) // Convert back to int before use. + } else { + sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) + } + if err != nil { + return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + } + } + + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f95803f91..79589e3c8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 689cad997..be005df24 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -30,6 +30,13 @@ type Protocol interface { // Protocol returns the Linux netlink protocol value. Protocol() int + // CanSend returns true if this protocol may ever send messages. + // + // TODO(gvisor.dev/issue/1119): This is a workaround to allow + // advertising support for otherwise unimplemented features on sockets + // that will never send messages, thus making those features no-ops. + CanSend() bool + // ProcessMessage processes a single message from userspace. 
// // If err == nil, any messages added to ms will be sent back to the diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index cc70ac237..6b4a0ecf4 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -61,6 +61,11 @@ func (p *Protocol) Protocol() int { return linux.NETLINK_ROUTE } +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return true +} + // dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests. func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index b2732ca29..4a1b87a9a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -53,6 +54,8 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) +var errNoFilter = syserr.New("no filter attached", linux.ENOENT) + // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() @@ -61,7 +64,7 @@ var netlinkSocketDevice = device.NewAnonDevice() // This implementation only supports userspace sending and receiving messages // to/from the kernel. // -// Socket implements socket.Socket. +// Socket implements socket.Socket and transport.Credentialer. // // +stateify savable type Socket struct { @@ -104,9 +107,19 @@ type Socket struct { // sendBufferSize is the send buffer "size". We don't actually have a // fixed buffer but only consume this many bytes. sendBufferSize uint32 + + // passcred indicates if this socket wants SCM credentials. + passcred bool + + // filter indicates that this socket has a BPF filter "installed". + // + // TODO(gvisor.dev/issue/1119): We don't actually support filtering, + // this is just bookkeeping for tracking add/remove. + filter bool } var _ socket.Socket = (*Socket)(nil) +var _ transport.Credentialer = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -172,6 +185,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } +// Passcred implements transport.Credentialer.Passcred. +func (s *Socket) Passcred() bool { + s.mu.Lock() + passcred := s.passcred + s.mu.Unlock() + return passcred +} + +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. +func (s *Socket) ConnectedPasscred() bool { + // This socket is connected to the kernel, which doesn't need creds. + // + // This is arbitrary, as ConnectedPasscred on this type has no callers. + return false +} + // Ioctl implements fs.FileOperations.Ioctl. func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) { // TODO(b/68878065): no ioctls supported. @@ -309,9 +338,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. // We don't have limit on receiving size. 
return int32(math.MaxInt32), nil + case linux.SO_PASSCRED: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + var passcred int32 + if s.Passcred() { + passcred = 1 + } + return passcred, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } + case linux.SOL_NETLINK: switch name { case linux.NETLINK_BROADCAST_ERROR, @@ -348,6 +388,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy s.sendBufferSize = size s.mu.Unlock() return nil + case linux.SO_RCVBUF: if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -355,6 +396,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy // We don't have limit on receiving size. So just accept anything as // valid for compatibility. return nil + + case linux.SO_PASSCRED: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + passcred := usermem.ByteOrder.Uint32(opt) + + s.mu.Lock() + s.passcred = passcred != 0 + s.mu.Unlock() + return nil + + case linux.SO_ATTACH_FILTER: + // TODO(gvisor.dev/issue/1119): We don't actually + // support filtering. If this socket can't ever send + // messages, then there is nothing to filter and we can + // advertise support. Otherwise, be conservative and + // return an error. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + s.filter = true + s.mu.Unlock() + return nil + + case linux.SO_DETACH_FILTER: + // TODO(gvisor.dev/issue/1119): See above. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + filter := s.filter + s.filter = false + s.mu.Unlock() + + if !filter { + return errNoFilter + } + + return nil + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -483,6 +570,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ }) } +// kernelSCM implements control.SCMCredentials with credentials that represent +// the kernel itself rather than a Task. +// +// +stateify savable +type kernelSCM struct{} + +// Equals implements transport.CredentialsControlMessage.Equals. +func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool { + _, ok := oc.(kernelSCM) + return ok +} + +// Credentials implements control.SCMCredentials.Credentials. +func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { + return 0, auth.RootUID, auth.RootGID +} + +// kernelCreds is the concrete version of kernelSCM used in all creds. +var kernelCreds = &kernelSCM{} + // sendResponse sends the response messages in ms back to userspace. func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. @@ -491,10 +598,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error bufs = append(bufs, m.Finalize()) } + // All messages are from the kernel. + cms := transport.ControlMessages{ + Credentials: kernelCreds, + } + if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. 
if err != nil && err != syserr.ErrWouldBlock { @@ -521,7 +633,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error // Add the dump_done_errno payload. m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD new file mode 100644 index 000000000..0777f3baf --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "uevent", + srcs = ["protocol.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netlink", + "//pkg/syserr", + ], +) diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go new file mode 100644 index 000000000..b5d7808d7 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -0,0 +1,60 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol. +// +// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does +// not support any device events, so these sockets never send any messages. +package uevent + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/syserr" +) + +// Protocol implements netlink.Protocol. +// +// +stateify savable +type Protocol struct{} + +var _ netlink.Protocol = (*Protocol)(nil) + +// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol. +func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) { + return &Protocol{}, nil +} + +// Protocol implements netlink.Protocol.Protocol. +func (p *Protocol) Protocol() int { + return linux.NETLINK_KOBJECT_UEVENT +} + +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return false +} + +// ProcessMessage implements netlink.Protocol.ProcessMessage. +func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { + // Silently ignore all messages. + return nil +} + +// init registers the NETLINK_KOBJECT_UEVENT provider. 
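//
// After this registration, a sandboxed socket(AF_NETLINK, SOCK_RAW,
// NETLINK_KOBJECT_UEVENT) call is expected to be routed to NewProtocol
// above, yielding a socket that accepts writes but never delivers events.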
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol)
+}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0ae573b45..140851c17 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -53,6 +53,7 @@ import (
 "gvisor.dev/gvisor/pkg/syserror"
 "gvisor.dev/gvisor/pkg/tcpip"
 "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
 "gvisor.dev/gvisor/pkg/tcpip/stack"
 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
 "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -148,6 +149,10 @@ var Metrics = tcpip.Stats{
 TCP: tcpip.TCPStats{
 ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
 PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
+ CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections currently in either the ESTABLISHED or CLOSE-WAIT state."),
+ EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state."),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to the CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of a keep-alive timeout."),
 ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
 ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
 ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
@@ -278,7 +283,7 @@ type SocketOperations struct {
 // New creates a new endpoint socket.
 func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
 if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
+ if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
 return nil, syserr.TranslateNetstackError(err)
 }
 }
@@ -296,6 +301,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
 var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
 var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
 
 // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
 // netstack representation taking any addresses into account.
@@ -307,12 +313,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
 }
 
 // AddressAndFamily reads a sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
-// AF_INET6 addresses.
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
 //
 // strict indicates whether addresses with the AF_UNSPEC family are accepted or not.
 //
-// AddressAndFamily returns an address, its family.
+// AddressAndFamily returns an address and its family.
 func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
 // Make sure we have at least 2 bytes for the address family.
 if len(addr) < 2 {
@@ -320,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
 }
 
 family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
+ if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) {
 return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
 }
@@ -371,6 +377,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
 }
 return out, family, nil
 
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
 case linux.AF_UNSPEC:
 return tcpip.FullAddress{}, family, nil
 
@@ -1035,8 +1057,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
 return nil, syserr.ErrInvalidArgument
 }
 
- var v tcpip.DelayOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ if err != nil {
 return nil, syserr.TranslateNetstackError(err)
 }
@@ -1105,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
 
 return int32(time.Duration(v) / time.Second), nil
 
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Millisecond), nil
+
 case linux.TCP_INFO:
 var v tcpip.TCPInfoOption
 if err := ep.GetSockOpt(&v); err != nil {
@@ -1153,6 +1187,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
 copy(b, v)
 return b, nil
 
+ case linux.TCP_LINGER2:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPLingerTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
 default:
 emitUnimplementedEventTCP(t, name)
 }
@@ -1477,11 +1523,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
 }
 
 v := usermem.ByteOrder.Uint32(optVal)
- var o tcpip.DelayOption
+ var o int
 if v == 0 {
 o = 1
 }
- return syserr.TranslateNetstackError(ep.SetSockOpt(o))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
 
 case linux.TCP_CORK:
 if len(optVal) < sizeOfInt32 {
@@ -1529,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
 }
 return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
 
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
 case linux.TCP_CONGESTION:
 v := tcpip.CongestionControlOption(optVal)
 if err := ep.SetSockOpt(v); err != nil {
@@ -1536,6 +1593,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
 }
 return nil
 
+ case linux.TCP_LINGER2:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+
 case linux.TCP_REPAIR_OPTIONS:
 t.Kernel().EmitUnimplementedEvent(t)
 
@@ -1951,12 +2016,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
 return &out, uint32(2 + l)
 }
 return &out, uint32(3 + l)
+
 case linux.AF_INET:
 var out linux.SockAddrInet
 copy(out.Addr[:], addr.Addr)
 out.Family = linux.AF_INET
 out.Port = htons(addr.Port)
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInetSize)
+
 case linux.AF_INET6:
 var out linux.SockAddrInet6
 if len(addr.Addr) == 4 {
@@ -1972,7 +2039,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
 if isLinkLocal(addr.Addr) {
 out.Scope_id = uint32(addr.NIC)
 }
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
 default:
 return nil, 0
 }
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 357a664cc..2d2c1ba2a 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,6 +62,10 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
 }
 
 case linux.SOCK_RAW:
+ // TODO(b/142504697): "In order to create a raw socket, a
+ // process must have the CAP_NET_RAW capability in the user
+ // namespace that governs its network namespace." - raw(7)
+
 // Raw sockets require CAP_NET_RAW.
 creds := auth.CredentialsFromContext(ctx)
 if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
 return 0, true, syserr.ErrProtocolNotSupported
 }
 
-// Socket creates a new socket object for the AF_INET or AF_INET6 family.
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
 func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
 // Fail right away if we don't have a stack.
 stack := t.NetworkContext()
@@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
 return nil, nil
 }
 
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
 // Figure out the transport protocol.
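// (For example, a guest socket(AF_INET, SOCK_STREAM, 0) resolves to the
// TCP transport protocol here, while SOCK_DGRAM resolves to UDP.)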
transProto, associated, err := getTransportProtocol(t, stype, protocol) if err != nil { @@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return New(t, p.family, stype, int(transProto), wq, ep) } +func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { + // TODO(b/142504697): "In order to create a packet socket, a process + // must have the CAP_NET_RAW capability in the user namespace that + // governs its network namespace." - packet(7) + + // Packet sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(t) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return nil, syserr.ErrNotPermitted + } + + // "cooked" packets don't contain link layer information. + var cooked bool + switch stype { + case linux.SOCK_DGRAM: + cooked = true + case linux.SOCK_RAW: + cooked = false + default: + return nil, syserr.ErrProtocolNotSupported + } + + // protocol is passed in network byte order, but netstack wants it in + // host order. + netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + + wq := &waiter.Queue{} + ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return New(t, linux.AF_PACKET, stype, protocol, wq, ep) +} + // Pair just returns nil sockets (not supported). func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } -// init registers socket providers for AF_INET and AF_INET6. +// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET. func init() { // Providers backed by netstack. p := []provider{ @@ -138,6 +184,9 @@ func init() { family: linux.AF_INET6, netProto: ipv6.ProtocolNumber, }, + { + family: linux.AF_PACKET, + }, } for i := range p { diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fda0156e5..a0db2d4fd 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() + switch stats := stat.(type) { + case *inet.StatSNMPIP: + ip := Metrics.IP + *stats = inet.StatSNMPIP{ + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. 
+ } + case *inet.StatSNMPICMP: + in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + *stats = inet.StatSNMPICMP{ + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + in.DstUnreachable.Value(), // InDestUnreachs. + in.TimeExceeded.Value(), // InTimeExcds. + in.ParamProblem.Value(), // InParmProbs. + in.SrcQuench.Value(), // InSrcQuenchs. + in.Redirect.Value(), // InRedirects. + in.Echo.Value(), // InEchos. + in.EchoReply.Value(), // InEchoReps. + in.Timestamp.Value(), // InTimestamps. + in.TimestampReply.Value(), // InTimestampReps. + in.InfoRequest.Value(), // InAddrMasks. + in.InfoReply.Value(), // InAddrMaskReps. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. + } + case *inet.StatSNMPTCP: + tcp := Metrics.TCP + // RFC 2012 (updates 1213): SNMPv2-MIB-TCP. + *stats = inet.StatSNMPTCP{ + 1, // RtoAlgorithm. + 200, // RtoMin. + 120000, // RtoMax. + (1<<64 - 1), // MaxConn. + tcp.ActiveConnectionOpenings.Value(), // ActiveOpens. + tcp.PassiveConnectionOpenings.Value(), // PassiveOpens. + tcp.FailedConnectionAttempts.Value(), // AttemptFails. + tcp.EstablishedResets.Value(), // EstabResets. + tcp.CurrentEstablished.Value(), // CurrEstab. + tcp.ValidSegmentsReceived.Value(), // InSegs. + tcp.SegmentsSent.Value(), // OutSegs. + tcp.Retransmits.Value(), // RetransSegs. + tcp.InvalidSegmentsReceived.Value(), // InErrs. + tcp.ResetsSent.Value(), // OutRsts. + tcp.ChecksumErrors.Value(), // InCsumErrors. + } + case *inet.StatSNMPUDP: + udp := Metrics.UDP + *stats = inet.StatSNMPUDP{ + udp.PacketsReceived.Value(), // InDatagrams. + udp.UnknownPortErrors.Value(), // NoPorts. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + udp.PacketsSent.Value(), // OutDatagrams. + udp.ReceiveBufferErrors.Value(), // RcvbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + } + default: + return syserr.ErrEndpointOperation.ToError() + } + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. 
+func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 3a6baa308..4668b87d1 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -37,6 +37,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f7878a760 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto index 9586f5923..b677e9eb3 100644 --- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto +++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto @@ -3,7 +3,6 @@ syntax = "proto3"; // package syscall_rpc is a set of networking related system calls that can be // forwarded to a socket gofer. // -// TODO(b/77963526): Document individual RPCs. package syscall_rpc; message SendmsgRequest { diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 8c250c325..2389a9cdb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -43,6 +43,11 @@ type ControlMessages struct { IP tcpip.ControlMessages } +// Release releases Unix domain socket credentials and rights. +func (c *ControlMessages) Release() { + c.Unix.Release() +} + // Socket is the interface containing socket syscalls used by the syscall layer // to redirect them to the appropriate implementation. 
type Socket interface { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..885758054 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -118,6 +118,9 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { func extractPath(sockaddr []byte) (string, *syserr.Error) { addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 72ebf766d..d46421199 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -14,6 +14,7 @@ go_library( "open.go", "poll.go", "ptrace.go", + "select.go", "signal.go", "socket.go", "strace.go", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 5d57b75af..e603f858f 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -40,7 +40,7 @@ var linuxAMD64 = SyscallMap{ 20: makeSyscallInfo("writev", FD, WriteIOVec, Hex), 21: makeSyscallInfo("access", Path, Oct), 22: makeSyscallInfo("pipe", PipeFDs), - 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval), + 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval), 24: makeSyscallInfo("sched_yield"), 25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex), 26: makeSyscallInfo("msync", Hex, Hex, Hex), @@ -287,7 +287,7 @@ var linuxAMD64 = SyscallMap{ 267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex), 268: makeSyscallInfo("fchmodat", FD, Path, Mode), 269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex), - 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex), + 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet), 271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex), 272: makeSyscallInfo("unshare", CloneFlags), 273: makeSyscallInfo("set_robust_list", Hex, Hex), @@ -335,5 +335,33 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 318: makeSyscallInfo("getrandom", Hex, Hex, Hex), + 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close. 
+ 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex), + 321: makeSyscallInfo("bpf", Hex, Hex, Hex), + 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex), + 323: makeSyscallInfo("userfaultfd", Hex), + 324: makeSyscallInfo("membarrier", Hex, Hex), + 325: makeSyscallInfo("mlock2", Hex, Hex, Hex), + 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex), + 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex), + 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex), + 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex), + 330: makeSyscallInfo("pkey_alloc", Hex, Hex), + 331: makeSyscallInfo("pkey_free", Hex), 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), + 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet), + 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex), + 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex), + 425: makeSyscallInfo("io_uring_setup", Hex, Hex), + 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex), + 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex), + 428: makeSyscallInfo("open_tree", FD, Path, Hex), + 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex), + 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close. + 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex), + 432: makeSyscallInfo("fsmount", FD, Hex, Hex), + 433: makeSyscallInfo("fspick", FD, Path, Hex), + 434: makeSyscallInfo("pidfd_open", Hex, Hex), + 435: makeSyscallInfo("clone3", Hex, Hex), } diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go new file mode 100644 index 000000000..c77d418e6 --- /dev/null +++ b/pkg/sentry/strace/select.go @@ -0,0 +1,56 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package strace + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/usermem" +) + +func fdsFromSet(t *kernel.Task, set []byte) []int { + var fds []int + // Append n if the n-th bit is 1. + for i, v := range set { + for j := 0; j < 8; j++ { + if (v>>j)&1 == 1 { + fds = append(fds, i*8+j) + } + } + } + return fds +} + +func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string { + if nfds < 0 { + return fmt.Sprintf("%#x (negative nfds)", addr) + } + if addr == 0 { + return "null" + } + + // Calculate the size of the fd set (one bit per fd). 
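+	// For example, nfds = 10 gives nBytes = 2 and nBitsInLastPartialByte = 2;
+	// CopyInFDSet below then clears the top six bits of the second byte.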
+ nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 + + set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte) + if err != nil { + return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err) + } + + return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set)) +} diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 94334f6d2..51f2efb39 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -208,6 +208,15 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) i += linux.SizeOfControlMessageHeader width := t.Arch().Width() length := int(h.Length) - linux.SizeOfControlMessageHeader + if length < 0 { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 311389547..629c1f308 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -439,6 +439,8 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer())) case PollFDs: output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false)) + case SelectFDSet: + output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer())) case Oct: output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8)) case Hex: diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto index 4b2f73a5f..906c52c51 100644 --- a/pkg/sentry/strace/strace.proto +++ b/pkg/sentry/strace/strace.proto @@ -32,8 +32,7 @@ message Strace { } } -message StraceEnter { -} +message StraceEnter {} message StraceExit { // Return value formatted as string. diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go index 3c389d375..e5d486c4e 100644 --- a/pkg/sentry/strace/syscalls.go +++ b/pkg/sentry/strace/syscalls.go @@ -206,6 +206,10 @@ const ( // PollFDs is an array of struct pollfd. The number of entries in the // array is in the next argument. PollFDs + + // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of + // fds represented must be the first argument. 
+ SelectFDSet ) // defaultFormat is the syscall argument format to use if the actual format is diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index fb2c1777f..6766ba587 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -49,6 +49,7 @@ go_library( "sys_tls.go", "sys_utsname.go", "sys_write.go", + "sys_xattr.go", "timespec.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", @@ -79,6 +80,7 @@ go_library( "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", + "//pkg/sentry/loader", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/safemem", diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go index 444f2b004..07961dad9 100644 --- a/pkg/sentry/syscalls/linux/flags.go +++ b/pkg/sentry/syscalls/linux/flags.go @@ -50,5 +50,6 @@ func linuxToFlags(mask uint) fs.FileFlags { Directory: mask&linux.O_DIRECTORY != 0, Async: mask&linux.O_ASYNC != 0, LargeFile: mask&linux.O_LARGEFILE != 0, + Truncate: mask&linux.O_TRUNC != 0, } } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index b317cb99d..68589a377 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -16,7 +16,12 @@ package linux const ( - _LINUX_SYSNAME = "Linux" - _LINUX_RELEASE = "4.4.0" - _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016" + // LinuxSysname is the OS name advertised by gVisor. + LinuxSysname = "Linux" + + // LinuxRelease is the Linux release version number advertised by gVisor. + LinuxRelease = "4.4.0" + + // LinuxVersion is the version info advertised by gVisor. + LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016" ) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go index e215ac049..272ae9991 100644 --- a/pkg/sentry/syscalls/linux/linux64_amd64.go +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -34,9 +34,9 @@ var AMD64 = &kernel.SyscallTable{ // guides the interface provided by this syscall table. The build // version is that for a clean build with default kernel config, at 5 // minutes after v4.4 was tagged. 
- Sysname: _LINUX_SYSNAME, - Release: _LINUX_RELEASE, - Version: _LINUX_VERSION, + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, }, AuditNumber: linux.AUDIT_ARCH_X86_64, Table: map[uintptr]kernel.Syscall{ @@ -228,10 +228,10 @@ var AMD64 = &kernel.SyscallTable{ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), 187: syscalls.Supported("readahead", Readahead), - 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 188: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil), 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 191: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil), 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), @@ -260,7 +260,7 @@ var AMD64 = &kernel.SyscallTable{ 217: syscalls.Supported("getdents64", Getdents64), 218: syscalls.Supported("set_tid_address", SetTidAddress), 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), 222: syscalls.Supported("timer_create", TimerCreate), 223: syscalls.Supported("timer_settime", TimerSettime), @@ -362,16 +362,36 @@ var AMD64 = &kernel.SyscallTable{ 319: syscalls.Supported("memfd_create", MemfdCreate), 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 322: syscalls.Supported("execveat", Execveat), 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) + 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - // Syscalls after 325 are "backports" from versions of Linux after 4.4. + // Syscalls implemented after 325 are "backports" from versions + // of Linux after 4.4. 
326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), 327: syscalls.Supported("preadv2", Preadv2), 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 332: syscalls.Supported("statx", Statx), + 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. + 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, Emulate: map[usermem.Addr]uintptr{ diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go index f82dfac31..3b584eed9 100644 --- a/pkg/sentry/syscalls/linux/linux64_arm64.go +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -30,9 +30,9 @@ var ARM64 = &kernel.SyscallTable{ OS: abi.Linux, Arch: arch.ARM64, Version: kernel.Version{ - Sysname: _LINUX_SYSNAME, - Release: _LINUX_RELEASE, - Version: _LINUX_VERSION, + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, }, AuditNumber: linux.AUDIT_ARCH_AARCH64, Table: map[uintptr]kernel.Syscall{ @@ -41,10 +41,10 @@ var ARM64 = &kernel.SyscallTable{ 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. 
User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 5: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil), 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 8: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil), 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), @@ -224,7 +224,7 @@ var ARM64 = &kernel.SyscallTable{ 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), - 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), @@ -295,14 +295,33 @@ var ARM64 = &kernel.SyscallTable{ 278: syscalls.Supported("getrandom", GetRandom), 279: syscalls.Supported("memfd_create", MemfdCreate), 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 281: syscalls.Supported("execveat", Execveat), 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) + 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), 286: syscalls.Supported("preadv2", Preadv2), 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil), + 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil), + 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil), 291: syscalls.Supported("statx", Statx), + 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil), + 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil), + + // Linux skips ahead to syscall 424 to sync numbers between arches. 
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil), + 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil), + 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil), + 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil), + 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil), + 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil), + 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil), + 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil), + 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil), + 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), + 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), + 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), }, Emulate: map[usermem.Addr]uintptr{}, diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index b9a8e3e21..9bc2445a5 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -169,10 +169,14 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint if dirPath { return syserror.ENOTDIR } - if flags&linux.O_TRUNC != 0 { - if err := d.Inode.Truncate(t, d, 0); err != nil { - return err - } + } + + // Truncate is called when O_TRUNC is specified for any kind of + // existing Dirent. Behavior is delegated to the entry's Truncate + // implementation. + if flags&linux.O_TRUNC != 0 { + if err := d.Inode.Truncate(t, d, 0); err != nil { + return err } } @@ -396,7 +400,9 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l return err } - // Should we truncate the file? + // Truncate is called when O_TRUNC is specified for any kind of + // existing Dirent. Behavior is delegated to the entry's Truncate + // implementation. if flags&linux.O_TRUNC != 0 { if err := found.Inode.Truncate(t, found, 0); err != nil { return err @@ -834,23 +840,40 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(newfd), nil, nil } -func fGetOwn(t *kernel.Task, file *fs.File) int32 { +func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx { ma := file.Async(nil) if ma == nil { - return 0 + return linux.FOwnerEx{} } a := ma.(*fasync.FileAsync) ot, otg, opg := a.Owner() switch { case ot != nil: - return int32(t.PIDNamespace().IDOfTask(ot)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_TID, + PID: int32(t.PIDNamespace().IDOfTask(ot)), + } case otg != nil: - return int32(t.PIDNamespace().IDOfThreadGroup(otg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PID, + PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)), + } case opg != nil: - return int32(-t.PIDNamespace().IDOfProcessGroup(opg)) + return linux.FOwnerEx{ + Type: linux.F_OWNER_PGRP, + PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)), + } default: - return 0 + return linux.FOwnerEx{} + } +} + +func fGetOwn(t *kernel.Task, file *fs.File) int32 { + owner := fGetOwnEx(t, file) + if owner.Type == linux.F_OWNER_PGRP { + return -owner.PID } + return owner.PID } // fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux. 
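The fGetOwnEx/fGetOwn pair above preserves the legacy F_GETOWN encoding, in which a process-group owner is reported as a negated ID, while F_GETOWN_EX returns an explicit (type, PID) pair. A self-contained sketch of that conversion, using local stand-ins for the linux package's F_OWNER_* constants and FOwnerEx struct:

```go
package main

import "fmt"

// Stand-ins for linux.F_OWNER_TID/PID/PGRP.
const (
	ownerTID  = 0
	ownerPID  = 1
	ownerPGRP = 2
)

// fOwnerEx mirrors linux.FOwnerEx: an explicit owner type plus an ID.
type fOwnerEx struct {
	Type int32
	PID  int32
}

// legacyOwn collapses an extended owner into the legacy F_GETOWN value:
// process-group owners come back negated, everything else as-is.
func legacyOwn(o fOwnerEx) int32 {
	if o.Type == ownerPGRP {
		return -o.PID
	}
	return o.PID
}

func main() {
	fmt.Println(legacyOwn(fOwnerEx{Type: ownerPID, PID: 123}))  // 123
	fmt.Println(legacyOwn(fOwnerEx{Type: ownerPGRP, PID: 123})) // -123
}
```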
@@ -895,11 +918,13 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall t.FDTable().SetFlags(fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) + return 0, nil, nil case linux.F_GETFL: return uintptr(file.Flags().ToLinux()), nil, nil case linux.F_SETFL: flags := uint(args[2].Uint()) file.SetFlags(linuxToFlags(flags).Settable()) + return 0, nil, nil case linux.F_SETLK, linux.F_SETLKW: // In Linux the file system can choose to provide lock operations for an inode. // Normally pipe and socket types lack lock operations. We diverge and use a heavy @@ -1002,6 +1027,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_SETOWN: fSetOwn(t, file, args[2].Int()) return 0, nil, nil + case linux.F_GETOWN_EX: + addr := args[2].Pointer() + owner := fGetOwnEx(t, file) + _, err := t.CopyOut(addr, &owner) + return 0, nil, err + case linux.F_SETOWN_EX: + addr := args[2].Pointer() + var owner linux.FOwnerEx + n, err := t.CopyIn(addr, &owner) + if err != nil { + return 0, nil, err + } + a := file.Async(fasync.New).(*fasync.FileAsync) + switch owner.Type { + case linux.F_OWNER_TID: + task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) + if task == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerTask(t, task) + return uintptr(n), nil, nil + case linux.F_OWNER_PID: + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID)) + if tg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerThreadGroup(t, tg) + return uintptr(n), nil, nil + case linux.F_OWNER_PGRP: + pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID)) + if pg == nil { + return 0, nil, syserror.ESRCH + } + a.SetOwnerProcessGroup(t, pg) + return uintptr(n), nil, nil + default: + return 0, nil, syserror.EINVAL + } case linux.F_GET_SEALS: val, err := tmpfs.GetSeals(file.Dirent.Inode) return uintptr(val), nil, err @@ -1029,7 +1092,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Everything else is not yet supported. return 0, nil, syserror.EINVAL } - return 0, nil, nil } const ( @@ -1483,6 +1545,8 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if fs.IsDir(d.Inode.StableAttr) { return syserror.EISDIR } + // In contrast to open(O_TRUNC), truncate(2) is only valid for file + // types. if !fs.IsFile(d.Inode.StableAttr) { return syserror.EINVAL } @@ -1521,7 +1585,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EINVAL } - // Note that this is different from truncate(2) above, where a + // In contrast to open(O_TRUNC), truncate(2) is only valid for file + // types. Note that this is different from truncate(2) above, where a // directory returns EISDIR. if !fs.IsFile(file.Dirent.Inode.StableAttr) { return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b9bd25464..bde17a767 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -226,6 +226,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if mask == 0 { return 0, nil, syserror.EINVAL } + if val <= 0 { + // The Linux kernel wakes one waiter even if val is + // non-positive. 
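+			// For example, a FUTEX_WAKE call with val = 0 still wakes one
+			// waiter, so val = 0 is not a usable no-op probe.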
+ val = 1 + } n, err := t.Futex().Wake(t, addr, private, mask, val) return uintptr(n), nil, err @@ -242,6 +247,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FUTEX_WAKE_OP: op := uint32(val3) + if val <= 0 { + // The Linux kernel wakes one waiter even if val is + // non-positive. + val = 1 + } n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op) return uintptr(n), nil, err diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 7a13beac2..2b2df989a 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) return remainingTimeout, n, err } +// CopyInFDSet copies an fd set from select(2)/pselect(2). +func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) { + set := make([]byte, nBytes) + + if addr != 0 { + if _, err := t.CopyIn(addr, &set); err != nil { + return nil, err + } + // If we only use part of the last byte, mask out the extraneous bits. + // + // N.B. This only works on little-endian architectures. + if nBitsInLastPartialByte != 0 { + set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte + } + } + return set, nil +} + func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) { if nfds < 0 || nfds > fileCap { return 0, syserror.EINVAL } - // Capture all the provided input vectors. - // - // N.B. This only works on little-endian architectures. - byteCount := (nfds + 7) / 8 - - bitsInLastPartialByte := uint(nfds % 8) - r := make([]byte, byteCount) - w := make([]byte, byteCount) - e := make([]byte, byteCount) + // Calculate the size of the fd sets (one bit per fd). + nBytes := (nfds + 7) / 8 + nBitsInLastPartialByte := nfds % 8 - if readFDs != 0 { - if _, err := t.CopyIn(readFDs, &r); err != nil { - return 0, err - } - // Mask out bits above nfds. - if bitsInLastPartialByte != 0 { - r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + // Capture all the provided input vectors. + r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if writeFDs != 0 { - if _, err := t.CopyIn(writeFDs, &w); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } - - if exceptFDs != 0 { - if _, err := t.CopyIn(exceptFDs, &e); err != nil { - return 0, err - } - if bitsInLastPartialByte != 0 { - e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte - } + e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte) + if err != nil { + return 0, err } // Count how many FDs are actually being requested so that we can build // a PollFD array. fdCount := 0 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { v := r[i] | w[i] | e[i] for v != 0 { v &= (v - 1) @@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Build the PollFD array. 
pfd := make([]linux.PollFD, 0, fdCount) var fd int32 - for i := 0; i < byteCount; i++ { + for i := 0; i < nBytes; i++ { rV, wV, eV := r[i], w[i], e[i] v := rV | wV | eV m := byte(1) @@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add } // Do the syscall, then count the number of bits set. - _, _, err := pollBlock(t, pfd, timeout) - if err != nil { + if _, _, err = pollBlock(t, pfd, timeout); err != nil { return 0, syserror.ConvertIntr(err, syserror.EINTR) } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index b5a72ce63..4b5aafcc0 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -447,16 +447,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - // Read the length if present. Reject negative values. + // Read the length. Reject negative values. optLen := int32(0) - if optLenAddr != 0 { - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { - return 0, nil, err - } - - if optLen < 0 { - return 0, nil, syserror.EINVAL - } + if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + return 0, nil, err + } + if optLen < 0 { + return 0, nil, syserror.EINVAL } // Call syscall implementation then copy both value and value len out. @@ -465,11 +462,9 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - if optLenAddr != 0 { - vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { - return 0, nil, err - } + vLen := int32(binary.Size(v)) + if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + return 0, nil, err } if v != nil { @@ -769,7 +764,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC - cms.Unix.Release() + cms.Release() } if int(msg.Flags) != mflags { @@ -789,24 +784,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } - defer cms.Unix.Release() + defer cms.Release() controlData := make([]byte, 0, msg.ControlLen) + controlData = control.PackControlMessages(t, cms, controlData) if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() { creds, _ := cms.Unix.Credentials.(control.SCMCredentials) controlData, mflags = control.PackCredentials(t, creds, controlData, mflags) } - if cms.IP.HasTimestamp { - controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData) - } - - if cms.IP.HasInq { - // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. - controlData = control.PackInq(t, cms.IP.Inq, controlData) - } - if cms.Unix.Rights != nil { controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags) } @@ -882,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag } n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) - cm.Unix.Release() + cm.Release() if e != nil { return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) } @@ -1065,7 +1052,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme } // Call the syscall implementation. 
- n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages}) + n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) if err != nil { controlMessages.Release() diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 8ab7ffa25..4115116ff 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -15,12 +15,15 @@ package linux import ( + "path" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" + "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -67,8 +70,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal argvAddr := args[1].Pointer() envvAddr := args[2].Pointer() - // Extract our arguments. - filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX) + return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0) +} + +// Execveat implements linux syscall execveat(2). +func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirFD := args[0].Int() + pathnameAddr := args[1].Pointer() + argvAddr := args[2].Pointer() + envvAddr := args[3].Pointer() + flags := args[4].Int() + + return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags) +} + +func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { + pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) if err != nil { return 0, nil, err } @@ -89,14 +106,66 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } } + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 + if !atEmptyPath && len(pathname) == 0 { + return 0, nil, syserror.ENOENT + } + resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 + root := t.FSContext().RootDirectory() defer root.DecRef() - wd := t.FSContext().WorkingDirectory() - defer wd.DecRef() + + var wd *fs.Dirent + var executable *fs.File + var closeOnExec bool + if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) { + // Even if the pathname is absolute, we may still need the wd + // for interpreter scripts if the path of the interpreter is + // relative. + wd = t.FSContext().WorkingDirectory() + } else { + // Need to extract the given FD. + f, fdFlags := t.FDTable().Get(dirFD) + if f == nil { + return 0, nil, syserror.EBADF + } + defer f.DecRef() + closeOnExec = fdFlags.CloseOnExec + + if atEmptyPath && len(pathname) == 0 { + executable = f + } else { + wd = f.Dirent + wd.IncRef() + if !fs.IsDir(wd.Inode.StableAttr) { + return 0, nil, syserror.ENOTDIR + } + } + } + if wd != nil { + defer wd.DecRef() + } // Load the new TaskContext. 
- maxTraversals := uint(linux.MaxSymlinkTraversals) - tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet()) + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: t.MountNamespace(), + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: resolveFinal, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 27cd2c336..ad4b67806 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Pwritev2 implements linux syscall pwritev2(2). -// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality. func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { // While the syscall is // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go new file mode 100644 index 000000000..97d9a65ea --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_xattr.go @@ -0,0 +1,169 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Getxattr implements linux syscall getxattr(2). +func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + + path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */) + if err != nil { + return 0, nil, err + } + + valueLen := 0 + err = fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + value, err := getxattr(t, d, dirPath, nameAddr) + if err != nil { + return err + } + + valueLen = len(value) + if size == 0 { + return nil + } + if size > linux.XATTR_SIZE_MAX { + size = linux.XATTR_SIZE_MAX + } + if valueLen > int(size) { + return syserror.ERANGE + } + + _, err = t.CopyOutBytes(valueAddr, []byte(value)) + return err + }) + if err != nil { + return 0, nil, err + } + return uintptr(valueLen), nil, nil +} + +// getxattr implements getxattr from the given *fs.Dirent. 
+func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr) (string, error) { + if dirPath && !fs.IsDir(d.Inode.StableAttr) { + return "", syserror.ENOTDIR + } + + if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil { + return "", err + } + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return "", err + } + + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return "", syserror.EOPNOTSUPP + } + + return d.Inode.Getxattr(name) +} + +// Setxattr implements linux syscall setxattr(2). +func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + pathAddr := args[0].Pointer() + nameAddr := args[1].Pointer() + valueAddr := args[2].Pointer() + size := args[3].SizeT() + flags := args[4].Uint() + + path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */) + if err != nil { + return 0, nil, err + } + + if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 { + return 0, nil, syserror.EINVAL + } + + return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error { + return setxattr(t, d, dirPath, nameAddr, valueAddr, size, flags) + }) +} + +// setxattr implements setxattr from the given *fs.Dirent. +func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint, flags uint32) error { + if dirPath && !fs.IsDir(d.Inode.StableAttr) { + return syserror.ENOTDIR + } + + if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil { + return err + } + + name, err := copyInXattrName(t, nameAddr) + if err != nil { + return err + } + + if size > linux.XATTR_SIZE_MAX { + return syserror.E2BIG + } + buf := make([]byte, size) + if _, err = t.CopyInBytes(valueAddr, buf); err != nil { + return err + } + value := string(buf) + + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + + return d.Inode.Setxattr(name, value) +} + +func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { + name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1) + if err != nil { + if err == syserror.ENAMETOOLONG { + return "", syserror.ERANGE + } + return "", err + } + if len(name) == 0 { + return "", syserror.ERANGE + } + return name, nil +} + +func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error { + // Restrict xattrs to regular files and directories. + // + // In Linux, this restriction technically only applies to xattrs in the + // "user.*" namespace, but we don't allow any other xattr prefixes anyway. 
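+	// Consequently, getxattr(2) on, say, a symlink fails with ENODATA,
+	// while setxattr(2) on one fails with EPERM.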
+ if !fs.IsRegular(i.StableAttr) && !fs.IsDir(i.StableAttr) { + if perms.Write { + return syserror.EPERM + } + return syserror.ENODATA + } + + return i.CheckPermission(t, perms) +} diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index d3a4cd943..18e212dff 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//third_party/gvsync:generic_seqatomic", + template = "//pkg/syncutil:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,8 +36,8 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", + "//pkg/syncutil", "//pkg/syserror", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go index 8d88396ba..7898851b3 100644 --- a/pkg/sentry/usermem/bytes_io.go +++ b/pkg/sentry/usermem/bytes_io.go @@ -102,19 +102,34 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) { } func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) { - blocks := make([]safemem.Block, 0, ars.NumRanges()) - for !ars.IsEmpty() { - ar := ars.Head() - n, err := b.rangeCheck(ar.Start, int(ar.Length())) - if n != 0 { - blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n])) + switch ars.NumRanges() { + case 0: + return safemem.BlockSeq{}, nil + case 1: + block, err := b.blockFromAddrRange(ars.Head()) + return safemem.BlockSeqOf(block), err + default: + blocks := make([]safemem.Block, 0, ars.NumRanges()) + for !ars.IsEmpty() { + block, err := b.blockFromAddrRange(ars.Head()) + if block.Len() != 0 { + blocks = append(blocks, block) + } + if err != nil { + return safemem.BlockSeqFromSlice(blocks), err + } + ars = ars.Tail() } - if err != nil { - return safemem.BlockSeqFromSlice(blocks), err - } - ars = ars.Tail() + return safemem.BlockSeqFromSlice(blocks), nil + } +} + +func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) { + n, err := b.rangeCheck(ar.Start, int(ar.Length())) + if n == 0 { + return safemem.Block{}, err } - return safemem.BlockSeqFromSlice(blocks), nil + return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err } // BytesIOSequence returns an IOSequence representing the given byte slice. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index eff4b44f6..e3e554b88 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -12,13 +12,14 @@ go_library( "file_description.go", "file_description_impl_util.go", "filesystem.go", + "filesystem_impl_util.go", "filesystem_type.go", "mount.go", "mount_unsafe.go", "options.go", + "pathname.go", "permissions.go", "resolving_path.go", - "syscalls.go", "testutil.go", "vfs.go", ], @@ -32,9 +33,9 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/syncutil", "//pkg/syserror", "//pkg/waiter", - "//third_party/gvsync", ], ) diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 7847854bc..9aa133bcb 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -39,8 +39,8 @@ Mount references are held by: - Mount: Each referenced Mount holds a reference on its parent, which is the mount containing its mount point. -- VirtualFilesystem: A reference is held on all Mounts that are attached - (reachable by Mount traversal). +- VirtualFilesystem: A reference is held on each Mount that has not been + umounted. 
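The bytes_io.go change above is a small allocation optimization: address-range sequences of length zero or one no longer build an intermediate []safemem.Block. A minimal sketch of the same fast-path pattern, with hypothetical stand-in types rather than the sentry's safemem package:

```go
package main

import "fmt"

// seq stands in for safemem.BlockSeq: a one-element sequence is stored
// inline, so the common single-range case allocates no slice.
type seq struct {
	one  []byte   // used when n == 1
	many [][]byte // used when n > 1
	n    int
}

func seqOf(b []byte) seq           { return seq{one: b, n: 1} }
func seqFromSlice(bs [][]byte) seq { return seq{many: bs, n: len(bs)} }

// fromRanges mirrors blocksFromAddrRanges: the zero- and one-range cases
// return early, before the slice of blocks is ever allocated.
func fromRanges(ranges [][]byte) seq {
	switch len(ranges) {
	case 0:
		return seq{}
	case 1:
		return seqOf(ranges[0])
	default:
		blocks := make([][]byte, 0, len(ranges))
		for _, r := range ranges {
			blocks = append(blocks, r)
		}
		return seqFromSlice(blocks)
	}
}

func main() {
	fmt.Println(fromRanges(nil).n)                     // 0
	fmt.Println(fromRanges([][]byte{{1, 2}}).n)        // 1, no slice built
	fmt.Println(fromRanges([][]byte{{1}, {2}, {3}}).n) // 3
}
```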
MountNamespace and FileDescription references are held by users of VFS. The expectation is that each `kernel.Task` holds a reference on its corresponding diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index 32cf9151b..705194ebc 100644 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go @@ -24,6 +24,9 @@ type contextID int const ( // CtxMountNamespace is a Context.Value key for a MountNamespace. CtxMountNamespace contextID = iota + + // CtxRoot is a Context.Value key for a VFS root. + CtxRoot ) // MountNamespaceFromContext returns the MountNamespace used by ctx. It does @@ -35,3 +38,13 @@ func MountNamespaceFromContext(ctx context.Context) *MountNamespace { } return nil } + +// RootFromContext returns the VFS root used by ctx. It takes a reference on +// the returned VirtualDentry. If ctx does not have a specific VFS root, +// RootFromContext returns a zero-value VirtualDentry. +func RootFromContext(ctx context.Context) VirtualDentry { + if v := ctx.Value(CtxRoot); v != nil { + return v.(VirtualDentry) + } + return VirtualDentry{} +} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 45912fc58..6209eb053 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,6 +16,7 @@ package vfs import ( "fmt" + "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/syserror" @@ -50,7 +51,7 @@ import ( // and not inodes. Furthermore, when parties outside the scope of VFS can // rename inodes on such filesystems, VFS generally cannot "follow" the rename, // both due to synchronization issues and because it may not even be able to -// name the destination path; this implies that it would in fact be *incorrect* +// name the destination path; this implies that it would in fact be incorrect // for Dentries to be associated with inodes on such filesystems. Consequently, // operations that are inode operations in Linux are FilesystemImpl methods // and/or FileDescriptionImpl methods in gVisor's VFS. Filesystems that do @@ -87,6 +88,9 @@ type Dentry struct { // children are child Dentries. children map[string]*Dentry + // mu synchronizes disowning and mounting over this Dentry. + mu sync.Mutex + // impl is the DentryImpl associated with this Dentry. impl is immutable. // This should be the last field in Dentry. impl DentryImpl @@ -114,7 +118,7 @@ func (d *Dentry) Impl() DentryImpl { type DentryImpl interface { // IncRef increments the Dentry's reference count. A Dentry with a non-zero // reference count must remain coherent with the state of the filesystem. - IncRef(fs *Filesystem) + IncRef() // TryIncRef increments the Dentry's reference count and returns true. If // the Dentry's reference count is zero, TryIncRef may do nothing and @@ -122,10 +126,10 @@ type DentryImpl interface { // guarantee that the Dentry is coherent with the state of the filesystem.) // // TryIncRef does not require that a reference is held on the Dentry. - TryIncRef(fs *Filesystem) bool + TryIncRef() bool // DecRef decrements the Dentry's reference count. - DecRef(fs *Filesystem) + DecRef() } // IsDisowned returns true if d is disowned. @@ -142,16 +146,20 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } -func (d *Dentry) incRef(fs *Filesystem) { - d.impl.IncRef(fs) +// IncRef increments d's reference count. +func (d *Dentry) IncRef() { + d.impl.IncRef() } -func (d *Dentry) tryIncRef(fs *Filesystem) bool { - return d.impl.TryIncRef(fs) +// TryIncRef increments d's reference count and returns true. 
If d's reference +// count is zero, TryIncRef may instead do nothing and return false. +func (d *Dentry) TryIncRef() bool { + return d.impl.TryIncRef() } -func (d *Dentry) decRef(fs *Filesystem) { - d.impl.DecRef(fs) +// DecRef decrements d's reference count. +func (d *Dentry) DecRef() { + d.impl.DecRef() } // These functions are exported so that filesystem implementations can use @@ -191,6 +199,18 @@ func (d *Dentry) HasChildren() bool { return len(d.children) != 0 } +// Children returns a map containing all of d's children. +func (d *Dentry) Children() map[string]*Dentry { + if !d.HasChildren() { + return nil + } + m := make(map[string]*Dentry) + for name, child := range d.children { + m[name] = child + } + return m +} + // InsertChild makes child a child of d with the given name. // // InsertChild is a mutator of d and child. @@ -228,36 +248,48 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent panic("d is already disowned") } } - vfs.mountMu.RLock() - if _, ok := mntns.mountpoints[d]; ok { - vfs.mountMu.RUnlock() + vfs.mountMu.Lock() + if mntns.mountpoints[d] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } - // Return with vfs.mountMu locked, which will be unlocked by - // AbortDeleteDentry or CommitDeleteDentry. + d.mu.Lock() + vfs.mountMu.Unlock() + // Return with d.mu locked to block attempts to mount over it; it will be + // unlocked by AbortDeleteDentry or CommitDeleteDentry. return nil } // AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion // fails. -func (vfs *VirtualFilesystem) AbortDeleteDentry() { - vfs.mountMu.RUnlock() +func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) { + d.mu.Unlock() } // CommitDeleteDentry must be called after the file represented by d is // deleted, and causes d to become disowned. // +// CommitDeleteDentry is a mutator of d and d.Parent(). +// // Preconditions: PrepareDeleteDentry was previously called on d. func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) { - delete(d.parent.children, d.name) + if d.parent != nil { + delete(d.parent.children, d.name) + } d.setDisowned() - // TODO: lazily unmount mounts at d - vfs.mountMu.RUnlock() + d.mu.Unlock() + if d.isMounted() { + vfs.forgetDisownedMountpoint(d) + } } // DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as // appropriate for in-memory filesystems that don't need to ensure that some // external state change succeeds before committing the deletion. +// +// DeleteDentry is a mutator of d and d.Parent(). +// +// Preconditions: d is a child Dentry. func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error { if err := vfs.PrepareDeleteDentry(mntns, d); err != nil { return err @@ -266,6 +298,27 @@ func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) err return nil } +// ForceDeleteDentry causes d to become disowned. It should only be used in +// cases where VFS has no ability to stop the deletion (e.g. d represents the +// local state of a file on a remote filesystem on which the file has already +// been deleted). +// +// ForceDeleteDentry is a mutator of d and d.Parent(). +// +// Preconditions: d is a child Dentry. 
+func (vfs *VirtualFilesystem) ForceDeleteDentry(d *Dentry) { + if checkInvariants { + if d.parent == nil { + panic("d is independent") + } + if d.IsDisowned() { + panic("d is already disowned") + } + } + d.mu.Lock() + vfs.CommitDeleteDentry(d) +} + // PrepareRenameDentry must be called before attempting to rename the file // represented by from. If to is not nil, it represents the file that will be // replaced or exchanged by the rename. If PrepareRenameDentry succeeds, the @@ -291,18 +344,21 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t } } } - vfs.mountMu.RLock() - if _, ok := mntns.mountpoints[from]; ok { - vfs.mountMu.RUnlock() + vfs.mountMu.Lock() + if mntns.mountpoints[from] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } if to != nil { - if _, ok := mntns.mountpoints[to]; ok { - vfs.mountMu.RUnlock() + if mntns.mountpoints[to] != 0 { + vfs.mountMu.Unlock() return syserror.EBUSY } + to.mu.Lock() } - // Return with vfs.mountMu locked, which will be unlocked by + from.mu.Lock() + vfs.mountMu.Unlock() + // Return with from.mu and to.mu locked, which will be unlocked by // AbortRenameDentry, CommitRenameReplaceDentry, or // CommitRenameExchangeDentry. return nil @@ -310,38 +366,76 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t // AbortRenameDentry must be called after PrepareRenameDentry if the rename // fails. -func (vfs *VirtualFilesystem) AbortRenameDentry() { - vfs.mountMu.RUnlock() +func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) { + from.mu.Unlock() + if to != nil { + to.mu.Unlock() + } } // CommitRenameReplaceDentry must be called after the file represented by from // is renamed without RENAME_EXCHANGE. If to is not nil, it represents the file // that was replaced by from. // +// CommitRenameReplaceDentry is a mutator of from, to, from.Parent(), and +// to.Parent(). +// // Preconditions: PrepareRenameDentry was previously called on from and to. // newParent.Child(newName) == to. func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, newParent *Dentry, newName string, to *Dentry) { - if to != nil { - to.setDisowned() - // TODO: lazily unmount mounts at d - } if newParent.children == nil { newParent.children = make(map[string]*Dentry) } newParent.children[newName] = from from.parent = newParent from.name = newName - vfs.mountMu.RUnlock() + from.mu.Unlock() + if to != nil { + to.setDisowned() + to.mu.Unlock() + if to.isMounted() { + vfs.forgetDisownedMountpoint(to) + } + } } // CommitRenameExchangeDentry must be called after the files represented by // from and to are exchanged by rename(RENAME_EXCHANGE). // +// CommitRenameExchangeDentry is a mutator of from, to, from.Parent(), and +// to.Parent(). +// // Preconditions: PrepareRenameDentry was previously called on from and to. func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) { from.parent, to.parent = to.parent, from.parent from.name, to.name = to.name, from.name from.parent.children[from.name] = from to.parent.children[to.name] = to - vfs.mountMu.RUnlock() + from.mu.Unlock() + to.mu.Unlock() +} + +// forgetDisownedMountpoint is called when a mount point is deleted to umount +// all mounts using it in all other mount namespaces. +// +// forgetDisownedMountpoint is analogous to Linux's +// fs/namespace.c:__detach_mounts(). 
+func (vfs *VirtualFilesystem) forgetDisownedMountpoint(d *Dentry) { + var ( + vdsToDecRef []VirtualDentry + mountsToDecRef []*Mount + ) + vfs.mountMu.Lock() + vfs.mounts.seq.BeginWrite() + for mnt := range vfs.mountpoints[d] { + vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(mnt, &umountRecursiveOptions{}, vdsToDecRef, mountsToDecRef) + } + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 3a9665800..df03886c3 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -20,8 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -38,43 +40,64 @@ type FileDescription struct { // operations. refs int64 + // statusFlags contains status flags, "initialized by open(2) and possibly + // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic + // memory operations. + statusFlags uint32 + // vd is the filesystem location at which this FileDescription was opened. // A reference is held on vd. vd is immutable. vd VirtualDentry + opts FileDescriptionOptions + // impl is the FileDescriptionImpl associated with this Filesystem. impl is // immutable. This should be the last field in FileDescription. impl FileDescriptionImpl } -// Init must be called before first use of fd. It takes references on mnt and -// d. -func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) { +// FileDescriptionOptions contains options to FileDescription.Init(). +type FileDescriptionOptions struct { + // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is + // usually only the case if O_DIRECT would actually have an effect. + AllowDirectIO bool +} + +// Init must be called before first use of fd. It takes ownership of references +// on mnt and d held by the caller. statusFlags is the initial file description +// status flags, which is usually the full set of flags passed to open(2). +func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) { fd.refs = 1 + fd.statusFlags = statusFlags | linux.O_LARGEFILE fd.vd = VirtualDentry{ mount: mnt, dentry: d, } - fd.vd.IncRef() + fd.opts = *opts fd.impl = impl } -// Impl returns the FileDescriptionImpl associated with fd. -func (fd *FileDescription) Impl() FileDescriptionImpl { - return fd.impl -} - -// VirtualDentry returns the location at which fd was opened. It does not take -// a reference on the returned VirtualDentry. -func (fd *FileDescription) VirtualDentry() VirtualDentry { - return fd.vd -} - // IncRef increments fd's reference count. func (fd *FileDescription) IncRef() { atomic.AddInt64(&fd.refs, 1) } +// TryIncRef increments fd's reference count and returns true. If fd's +// reference count is already zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fd. 
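// The implementation below is the standard lock-free refcount pattern:
// reload the counter, fail permanently once it has reached zero, and
// otherwise attempt to publish refs+1 with a compare-and-swap. A minimal
// standalone sketch of the same pattern (illustrative, independent of
// FileDescription):
//
//	type ref struct{ n int64 }
//
//	func (r *ref) tryIncRef() bool {
//		for {
//			n := atomic.LoadInt64(&r.n)
//			if n <= 0 {
//				return false // object may already be destroyed
//			}
//			if atomic.CompareAndSwapInt64(&r.n, n, n+1) {
//				return true // we now hold a reference
//			}
//		}
//	}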
+func (fd *FileDescription) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fd.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) { + return true + } + } +} + // DecRef decrements fd's reference count. func (fd *FileDescription) DecRef() { if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 { @@ -85,6 +108,82 @@ func (fd *FileDescription) DecRef() { } } +// Mount returns the mount on which fd was opened. It does not take a reference +// on the returned Mount. +func (fd *FileDescription) Mount() *Mount { + return fd.vd.mount +} + +// Dentry returns the dentry at which fd was opened. It does not take a +// reference on the returned Dentry. +func (fd *FileDescription) Dentry() *Dentry { + return fd.vd.dentry +} + +// VirtualDentry returns the location at which fd was opened. It does not take +// a reference on the returned VirtualDentry. +func (fd *FileDescription) VirtualDentry() VirtualDentry { + return fd.vd +} + +// StatusFlags returns file description status flags, as for fcntl(F_GETFL). +func (fd *FileDescription) StatusFlags() uint32 { + return atomic.LoadUint32(&fd.statusFlags) +} + +// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL). +func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error { + // Compare Linux's fs/fcntl.c:setfl(). + oldFlags := fd.StatusFlags() + // Linux documents this check as "O_APPEND cannot be cleared if the file is + // marked as append-only and the file is open for write", which would make + // sense. However, the check as actually implemented seems to be "O_APPEND + // cannot be changed if the file is marked as append-only". + if (flags^oldFlags)&linux.O_APPEND != 0 { + stat, err := fd.impl.Stat(ctx, StatOptions{ + // There is no mask bit for stx_attributes. + Mask: 0, + // Linux just reads inode::i_flags directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return err + } + if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) { + return syserror.EPERM + } + } + if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) { + stat, err := fd.impl.Stat(ctx, StatOptions{ + Mask: linux.STATX_UID, + // Linux's inode_owner_or_capable() just reads inode::i_uid + // directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return err + } + if stat.Mask&linux.STATX_UID == 0 { + return syserror.EPERM + } + if !CanActAsOwner(creds, auth.KUID(stat.UID)) { + return syserror.EPERM + } + } + if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { + return syserror.EINVAL + } + // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? + const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK + atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) + return nil +} + +// Impl returns the FileDescriptionImpl associated with fd. +func (fd *FileDescription) Impl() FileDescriptionImpl { + return fd.impl +} + // FileDescriptionImpl contains implementation details for a FileDescription. // Implementations of FileDescriptionImpl should contain their associated // FileDescription by value as their first field. @@ -104,14 +203,6 @@ type FileDescriptionImpl interface { // prevent the file descriptor from being closed. OnClose(ctx context.Context) error - // StatusFlags returns file description status flags, as for - // fcntl(F_GETFL).
- StatusFlags(ctx context.Context) (uint32, error) - - // SetStatusFlags sets file description status flags, as for - // fcntl(F_SETFL). - SetStatusFlags(ctx context.Context, flags uint32) error - // Stat returns metadata for the file represented by the FileDescription. Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) @@ -185,7 +276,21 @@ type FileDescriptionImpl interface { // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // TODO: extended attributes; file locking + // Listxattr returns all extended attribute names for the file. + Listxattr(ctx context.Context) ([]string, error) + + // Getxattr returns the value associated with the given extended attribute + // for the file. + Getxattr(ctx context.Context, name string) (string, error) + + // Setxattr changes the value associated with the given extended attribute + // for the file. + Setxattr(ctx context.Context, opts SetxattrOptions) error + + // Removexattr removes the given extended attribute from the file. + Removexattr(ctx context.Context, name string) error + + // TODO: file locking } // Dirent holds the information contained in struct linux_dirent64. @@ -214,3 +319,160 @@ type IterDirentsCallback interface { // called. Handle(dirent Dirent) bool } + +// OnClose is called when a file descriptor representing the FileDescription is +// closed. Returning a non-nil error should not prevent the file descriptor +// from being closed. +func (fd *FileDescription) OnClose(ctx context.Context) error { + return fd.impl.OnClose(ctx) +} + +// Stat returns metadata for the file represented by fd. +func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + return fd.impl.Stat(ctx, opts) +} + +// SetStat updates metadata for the file represented by fd. +func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error { + return fd.impl.SetStat(ctx, opts) +} + +// StatFS returns metadata for the filesystem containing the file represented +// by fd. +func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + return fd.impl.StatFS(ctx) +} + +// PRead reads from the file represented by fd into dst, starting at the given +// offset, and returns the number of bytes read. PRead is permitted to return +// partial reads with a nil error. +func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return fd.impl.PRead(ctx, dst, offset, opts) +} + +// Read is similar to PRead, but does not specify an offset. +func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + return fd.impl.Read(ctx, dst, opts) +} + +// PWrite writes src to the file represented by fd, starting at the given +// offset, and returns the number of bytes written. PWrite is permitted to +// return partial writes with a nil error. +func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return fd.impl.PWrite(ctx, src, offset, opts) +} + +// Write is similar to PWrite, but does not specify an offset. +func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return fd.impl.Write(ctx, src, opts) +} + +// IterDirents invokes cb on each entry in the directory represented by fd. 
If +// IterDirents has been called since the last call to Seek, it continues +// iteration from the end of the last call. +func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error { + return fd.impl.IterDirents(ctx, cb) +} + +// Seek changes fd's offset (assuming one exists) and returns its new value. +func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return fd.impl.Seek(ctx, offset, whence) +} + +// Sync has the semantics of fsync(2). +func (fd *FileDescription) Sync(ctx context.Context) error { + return fd.impl.Sync(ctx) +} + +// ConfigureMMap mutates opts to implement mmap(2) for the file represented by +// fd. +func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.impl.ConfigureMMap(ctx, opts) +} + +// Ioctl implements the ioctl(2) syscall. +func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.impl.Ioctl(ctx, uio, args) +} + +// Listxattr returns all extended attribute names for the file represented by +// fd. +func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) { + names, err := fd.impl.Listxattr(ctx) + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by default + // don't exist. + return nil, nil + } + return names, err +} + +// Getxattr returns the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) { + return fd.impl.Getxattr(ctx, name) +} + +// Setxattr changes the value associated with the given extended attribute for +// the file represented by fd. +func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return fd.impl.Setxattr(ctx, opts) +} + +// Removexattr removes the given extended attribute from the file represented +// by fd. +func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { + return fd.impl.Removexattr(ctx, name) +} + +// SyncFS instructs the filesystem containing fd to execute the semantics of +// syncfs(2). +func (fd *FileDescription) SyncFS(ctx context.Context) error { + return fd.vd.mount.fs.impl.Sync(ctx) +} + +// MappedName implements memmap.MappingIdentity.MappedName. +func (fd *FileDescription) MappedName(ctx context.Context) string { + vfsroot := RootFromContext(ctx) + s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd) + if vfsroot.Ok() { + vfsroot.DecRef() + } + return s +} + +// DeviceID implements memmap.MappingIdentity.DeviceID. +func (fd *FileDescription) DeviceID() uint64 { + stat, err := fd.impl.Stat(context.Background(), StatOptions{ + // There is no STATX_DEV; we assume that Stat will return it if it's + // available regardless of mask. + Mask: 0, + // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_sb->s_dev + // directly. + Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil { + return 0 + } + return uint64(linux.MakeDeviceID(uint16(stat.DevMajor), stat.DevMinor)) +} + +// InodeID implements memmap.MappingIdentity.InodeID. +func (fd *FileDescription) InodeID() uint64 { + stat, err := fd.impl.Stat(context.Background(), StatOptions{ + Mask: linux.STATX_INO, + // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly. 
+ Sync: linux.AT_STATX_DONT_SYNC, + }) + if err != nil || stat.Mask&linux.STATX_INO == 0 { + return 0 + } + return stat.Ino +} + +// Msync implements memmap.MappingIdentity.Msync. +func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error { + return fd.impl.Sync(ctx) +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 4fbad7840..3df49991c 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -127,6 +127,31 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } +// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// inode_operations::listxattr == NULL in Linux. +func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) { + // This isn't exactly accurate; see FileDescription.Listxattr. + return nil, syserror.ENOTSUP +} + +// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) { + return "", syserror.ENOTSUP +} + +// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { + return syserror.ENOTSUP +} + +// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// inode::i_opflags & IOP_XATTR == 0 in Linux. +func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { + return syserror.ENOTSUP +} + // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. @@ -252,3 +277,12 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6 fd.off = offset return offset, nil } + +// GenericConfigureMMap may be used by most implementations of +// FileDescriptionImpl.ConfigureMMap. +func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error { + opts.Mappable = m + opts.MappingIdentity = fd + fd.IncRef() + return nil +} diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 511b829fc..678be07fe 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -48,7 +48,7 @@ type genCountFD struct { func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription { var fd genCountFD - fd.vfsfd.Init(&fd, mnt, vfsd) + fd.vfsfd.Init(&fd, 0 /* statusFlags */, mnt, vfsd, &FileDescriptionOptions{}) fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd) return &fd.vfsfd } @@ -90,7 +90,7 @@ func TestGenCountFD(t *testing.T) { vfsObj := New() // vfs.New() vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{}) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{}) if err != nil { t.Fatalf("failed to create testfs root mount: %v", err) } @@ -103,7 +103,7 @@ func TestGenCountFD(t *testing.T) { // The first read causes Generate to be called to fill the FD's buffer. 
buf := make([]byte, 2) ioseq := usermem.BytesIOSequence(buf) - n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err := fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -112,17 +112,17 @@ func TestGenCountFD(t *testing.T) { } // A second read without seeking is still at EOF. - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 0 || err != io.EOF { t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err) } // Seeking to the beginning of the file causes it to be regenerated. - n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET) + n, err = fd.Seek(ctx, 0, linux.SEEK_SET) if n != 0 || err != nil { t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err) } - n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{}) + n, err = fd.Read(ctx, ioseq, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err) } @@ -131,7 +131,7 @@ func TestGenCountFD(t *testing.T) { } // PRead at the beginning of the file also causes it to be regenerated. - n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{}) + n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{}) if n != 1 || (err != nil && err != io.EOF) { t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err) } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 7a074b718..b766614e7 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -33,15 +34,28 @@ type Filesystem struct { // operations. refs int64 + // vfs is the VirtualFilesystem that uses this Filesystem. vfs is + // immutable. + vfs *VirtualFilesystem + // impl is the FilesystemImpl associated with this Filesystem. impl is // immutable. This should be the last field in Dentry. impl FilesystemImpl } // Init must be called before first use of fs. -func (fs *Filesystem) Init(impl FilesystemImpl) { +func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) { fs.refs = 1 + fs.vfs = vfsObj fs.impl = impl + vfsObj.filesystemsMu.Lock() + vfsObj.filesystems[fs] = struct{}{} + vfsObj.filesystemsMu.Unlock() +} + +// VirtualFilesystem returns the containing VirtualFilesystem. +func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem { + return fs.vfs } // Impl returns the FilesystemImpl associated with fs. @@ -49,14 +63,35 @@ func (fs *Filesystem) Impl() FilesystemImpl { return fs.impl } -func (fs *Filesystem) incRef() { +// IncRef increments fs' reference count. +func (fs *Filesystem) IncRef() { if atomic.AddInt64(&fs.refs, 1) <= 1 { - panic("Filesystem.incRef() called without holding a reference") + panic("Filesystem.IncRef() called without holding a reference") + } +} + +// TryIncRef increments fs' reference count and returns true. If fs' reference +// count is zero, TryIncRef does nothing and returns false. +// +// TryIncRef does not require that a reference is held on fs. +func (fs *Filesystem) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&fs.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) { + return true + } } } -func (fs *Filesystem) decRef() { +// DecRef decrements fs' reference count. 
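// Because DecRef below deregisters fs from vfs.filesystems before calling
// Release, and TryIncRef fails once the count reaches zero, the registry can
// be iterated safely. A sketch (illustrative; assumes a caller inside package
// vfs, e.g. for a future syncfs-like operation):
//
//	func (vfsObj *VirtualFilesystem) forEachFilesystem(f func(*Filesystem)) {
//		var fss []*Filesystem
//		vfsObj.filesystemsMu.Lock()
//		for fs := range vfsObj.filesystems {
//			if fs.TryIncRef() { // skip filesystems concurrently dying
//				fss = append(fss, fs)
//			}
//		}
//		vfsObj.filesystemsMu.Unlock()
//		for _, fs := range fss {
//			f(fs)
//			fs.DecRef()
//		}
//	}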
+func (fs *Filesystem) DecRef() { if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { + fs.vfs.filesystemsMu.Lock() + delete(fs.vfs.filesystems, fs) + fs.vfs.filesystemsMu.Unlock() fs.impl.Release() } else if refs < 0 { panic("Filesystem.decRef() called without holding a reference") @@ -151,5 +186,70 @@ type FilesystemImpl interface { // UnlinkAt removes the non-directory file at rp. UnlinkAt(ctx context.Context, rp *ResolvingPath) error - // TODO: d_path(); extended attributes; inotify_add_watch(); bind() + // ListxattrAt returns all extended attribute names for the file at rp. + ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) + + // GetxattrAt returns the value associated with the given extended + // attribute for the file at rp. + GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) + + // SetxattrAt changes the value associated with the given extended + // attribute for the file at rp. + SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + + // RemovexattrAt removes the given extended attribute from the file at rp. + RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + + // PrependPath prepends a path from vd to vd.Mount().Root() to b. + // + // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered + // before vd.Mount().Root(), PrependPath should stop prepending path + // components and return a PrependPathAtVFSRootError. + // + // If traversal of vd.Dentry()'s ancestors encounters an independent + // ("root") Dentry that is not vd.Mount().Root() (i.e. vd.Dentry() is not a + // descendant of vd.Mount().Root()), PrependPath should stop prepending + // path components and return a PrependPathAtNonMountRootError. + // + // Filesystems for which Dentries do not have meaningful paths may prepend + // an arbitrary descriptive string to b and then return a + // PrependPathSyntheticError. + // + // Most implementations can acquire the appropriate locks to ensure that + // Dentry.Name() and Dentry.Parent() are fixed for vd.Dentry() and all of + // its ancestors, then call GenericPrependPath. + // + // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. + PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error + + // TODO: inotify_add_watch(); bind() +} + +// PrependPathAtVFSRootError is returned by implementations of +// FilesystemImpl.PrependPath() when they encounter the contextual VFS root. +type PrependPathAtVFSRootError struct{} + +// Error implements error.Error. +func (PrependPathAtVFSRootError) Error() string { + return "vfs.FilesystemImpl.PrependPath() reached VFS root" +} + +// PrependPathAtNonMountRootError is returned by implementations of +// FilesystemImpl.PrependPath() when they encounter an independent ancestor +// Dentry that is not the Mount root. +type PrependPathAtNonMountRootError struct{} + +// Error implements error.Error. +func (PrependPathAtNonMountRootError) Error() string { + return "vfs.FilesystemImpl.PrependPath() reached root other than Mount root" +} + +// PrependPathSyntheticError is returned by implementations of +// FilesystemImpl.PrependPath() for which prepended names do not represent real +// paths. +type PrependPathSyntheticError struct{} + +// Error implements error.Error. 
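// A typical FilesystemImpl satisfies the PrependPath contract by holding a
// lock that stabilizes Dentry names and parents, then delegating to
// GenericPrependPath (added below in filesystem_impl_util.go). A sketch,
// assuming a hypothetical implementation-side renameMu:
//
//	func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
//		fs.renameMu.RLock()
//		defer fs.renameMu.RUnlock()
//		return vfs.GenericPrependPath(vfsroot, vd, b)
//	}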
+func (PrependPathSyntheticError) Error() string { + return "vfs.FilesystemImpl.PrependPath() prepended synthetic name" } diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go new file mode 100644 index 000000000..7315a588e --- /dev/null +++ b/pkg/sentry/vfs/filesystem_impl_util.go @@ -0,0 +1,69 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "strings" + + "gvisor.dev/gvisor/pkg/fspath" +) + +// GenericParseMountOptions parses a comma-separated list of options of the +// form "key" or "key=value", where neither key nor value contain commas, and +// returns it as a map. If str contains duplicate keys, then the last value +// wins. For example: +// +// str = "key0=value0,key1,key2=value2,key0=value3" -> map{'key0':'value3','key1':'','key2':'value2'} +// +// GenericParseMountOptions is not appropriate if values may contain commas, +// e.g. in the case of the mpol mount option for tmpfs(5). +func GenericParseMountOptions(str string) map[string]string { + m := make(map[string]string) + for _, opt := range strings.Split(str, ",") { + if len(opt) > 0 { + res := strings.SplitN(opt, "=", 2) + if len(res) == 2 { + m[res[0]] = res[1] + } else { + m[opt] = "" + } + } + } + return m +} + +// GenericPrependPath may be used by implementations of +// FilesystemImpl.PrependPath() for which a single statically-determined lock +// or set of locks is sufficient to ensure its preconditions (as opposed to +// e.g. per-Dentry locks). +// +// Preconditions: Dentry.Name() and Dentry.Parent() must be held constant for +// vd.Dentry() and all of its ancestors. +func GenericPrependPath(vfsroot, vd VirtualDentry, b *fspath.Builder) error { + mnt, d := vd.mount, vd.dentry + for { + if mnt == vfsroot.mount && d == vfsroot.dentry { + return PrependPathAtVFSRootError{} + } + if d == mnt.root { + return nil + } + if d.parent == nil { + return PrependPathAtNonMountRootError{} + } + b.PrependComponent(d.name) + d = d.parent + } +} diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index f401ad7f3..c335e206d 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -25,21 +25,21 @@ import ( // // FilesystemType is analogous to Linux's struct file_system_type. type FilesystemType interface { - // NewFilesystem returns a Filesystem configured by the given options, + // GetFilesystem returns a Filesystem configured by the given options, // along with its mount root. A reference is taken on the returned // Filesystem and Dentry. - NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) + GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) } -// NewFilesystemOptions contains options to FilesystemType.NewFilesystem. 
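// Usage of GenericParseMountOptions above (illustrative):
//
//	m := vfs.GenericParseMountOptions("ro,mode=755,uid=0,mode=700")
//	// m == map[string]string{"ro": "", "uid": "0", "mode": "700"}
//	// (the duplicate "mode" key keeps its last value)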
-type NewFilesystemOptions struct { +// GetFilesystemOptions contains options to FilesystemType.GetFilesystem. +type GetFilesystemOptions struct { // Data is the string passed as the 5th argument to mount(2), which is // usually a comma-separated list of filesystem-specific mount options. Data string // InternalData holds opaque FilesystemType-specific data. There is // intentionally no way for applications to specify InternalData; if it is - // not nil, the call to NewFilesystem originates from within the sentry. + // not nil, the call to GetFilesystem originates from within the sentry. InternalData interface{} } diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 11702f720..ec23ab0dd 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -18,6 +18,7 @@ import ( "math" "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -38,16 +39,12 @@ import ( // Mount is analogous to Linux's struct mount. (gVisor does not distinguish // between struct mount and struct vfsmount.) type Mount struct { - // The lower 63 bits of refs are a reference count. The MSB of refs is set - // if the Mount has been eagerly unmounted, as by umount(2) without the - // MNT_DETACH flag. refs is accessed using atomic memory operations. - refs int64 - - // The lower 63 bits of writers is the number of calls to - // Mount.CheckBeginWrite() that have not yet been paired with a call to - // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. - // writers is accessed using atomic memory operations. - writers int64 + // vfs, fs, and root are immutable. References are held on fs and root. + // + // Invariant: root belongs to fs. + vfs *VirtualFilesystem + fs *Filesystem + root *Dentry // key is protected by VirtualFilesystem.mountMu and // VirtualFilesystem.mounts.seq, and may be nil. References are held on @@ -57,13 +54,29 @@ type Mount struct { // key.parent.fs. key mountKey - // fs, root, and ns are immutable. References are held on fs and root (but - // not ns). - // - // Invariant: root belongs to fs. - fs *Filesystem - root *Dentry - ns *MountNamespace + // ns is the namespace in which this Mount was mounted. ns is protected by + // VirtualFilesystem.mountMu. + ns *MountNamespace + + // The lower 63 bits of refs are a reference count. The MSB of refs is set + // if the Mount has been eagerly umounted, as by umount(2) without the + // MNT_DETACH flag. refs is accessed using atomic memory operations. + refs int64 + + // children is the set of all Mounts for which Mount.key.parent is this + // Mount. children is protected by VirtualFilesystem.mountMu. + children map[*Mount]struct{} + + // umounted is true if VFS.umountRecursiveLocked() has been called on this + // Mount. VirtualFilesystem does not hold a reference on Mounts for which + // umounted is true. umounted is protected by VirtualFilesystem.mountMu. + umounted bool + + // The lower 63 bits of writers is the number of calls to + // Mount.CheckBeginWrite() that have not yet been paired with a call to + // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect. + // writers is accessed using atomic memory operations. + writers int64 } // A MountNamespace is a collection of Mounts. @@ -73,13 +86,16 @@ type Mount struct { // // MountNamespace is analogous to Linux's struct mnt_namespace. 
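// The mountpoints map below counts Mounts per mount-point Dentry rather than
// just recording membership: a single Dentry can serve as the mount point of
// several Mounts in one namespace (e.g. when its parent directory is itself
// reachable through multiple Mounts), and it only stops being a mount point
// when the count returns to zero. Illustrative sequence:
//
//	// mount at (mntA, d):  mntns.mountpoints[d] == 1
//	// mount at (mntB, d):  mntns.mountpoints[d] == 2 (same Dentry, different parent Mount)
//	// umount the second:   mntns.mountpoints[d] == 1; PrepareDeleteDentry(mntns, d) still returns EBUSY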
type MountNamespace struct { - refs int64 // accessed using atomic memory operations - // root is the MountNamespace's root mount. root is immutable. root *Mount - // mountpoints contains all Dentries which are mount points in this - // namespace. mountpoints is protected by VirtualFilesystem.mountMu. + // refs is the reference count. refs is accessed using atomic memory + // operations. + refs int64 + + // mountpoints maps all Dentries which are mount points in this namespace + // to the number of Mounts for which they are mount points. mountpoints is + // protected by VirtualFilesystem.mountMu. // // mountpoints is used to determine if a Dentry can be moved or removed // (which requires that the Dentry is not a mount point in the calling @@ -89,26 +105,27 @@ type MountNamespace struct { // MountNamespace; this is required to ensure that // VFS.PrepareDeleteDentry() and VFS.PrepareRemoveDentry() operate // correctly on unreferenced MountNamespaces. - mountpoints map[*Dentry]struct{} + mountpoints map[*Dentry]uint32 } // NewMountNamespace returns a new mount namespace with a root filesystem // configured by the given arguments. A reference is taken on the returned // MountNamespace. -func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *NewFilesystemOptions) (*MountNamespace, error) { +func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return nil, syserror.ENODEV } - fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts) if err != nil { return nil, err } mntns := &MountNamespace{ refs: 1, - mountpoints: make(map[*Dentry]struct{}), + mountpoints: make(map[*Dentry]uint32), } mntns.root = &Mount{ + vfs: vfs, fs: fs, root: root, ns: mntns, @@ -117,13 +134,13 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return mntns, nil } -// NewMount creates and mounts a new Filesystem. -func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *NewFilesystemOptions) error { +// MountAt creates and mounts a Filesystem configured by the given arguments. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { fsType := vfs.getFilesystemType(fsTypeName) if fsType == nil { return syserror.ENODEV } - fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts) + fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return err } @@ -131,17 +148,19 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti // lock ordering. 
vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{}) if err != nil { - root.decRef(fs) - fs.decRef() + root.DecRef() + fs.DecRef() return err } vfs.mountMu.Lock() + vd.dentry.mu.Lock() for { if vd.dentry.IsDisowned() { + vd.dentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef() - root.decRef(fs) - fs.decRef() + root.DecRef() + fs.DecRef() return syserror.ENOENT } // vd might have been mounted over between vfs.GetDentryAt() and @@ -153,36 +172,272 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti if nextmnt == nil { break } - nextmnt.incRef() - nextmnt.root.incRef(nextmnt.fs) + // It's possible that nextmnt has been umounted but not disconnected, + // in which case vfs no longer holds a reference on it, and the last + // reference may be concurrently dropped even though we're holding + // vfs.mountMu. + if !nextmnt.tryIncMountedRef() { + break + } + // This can't fail since we're holding vfs.mountMu. + nextmnt.root.IncRef() + vd.dentry.mu.Unlock() vd.DecRef() vd = VirtualDentry{ mount: nextmnt, dentry: nextmnt.root, } + vd.dentry.mu.Lock() } // TODO: Linux requires that either both the mount point and the mount root // are directories, or neither are, and returns ENOTDIR if this is not the // case. mntns := vd.mount.ns mnt := &Mount{ + vfs: vfs, fs: fs, root: root, ns: mntns, refs: 1, } - mnt.storeKey(vd.mount, vd.dentry) + vfs.mounts.seq.BeginWrite() + vfs.connectLocked(mnt, vd, mntns) + vfs.mounts.seq.EndWrite() + vd.dentry.mu.Unlock() + vfs.mountMu.Unlock() + return nil +} + +// UmountAt removes the Mount at the given path. +func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error { + if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 { + return syserror.EINVAL + } + + // MNT_FORCE is currently unimplemented except for the permission check. + if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) { + return syserror.EPERM + } + + vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{}) + if err != nil { + return err + } + defer vd.DecRef() + if vd.dentry != vd.mount.root { + return syserror.EINVAL + } + vfs.mountMu.Lock() + if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns { + vfs.mountMu.Unlock() + return syserror.EINVAL + } + + // TODO(jamieliu): Linux special-cases umount of the caller's root, which + // we don't implement yet (we'll just fail it since the caller holds a + // reference on it). + + vfs.mounts.seq.BeginWrite() + if opts.Flags&linux.MNT_DETACH == 0 { + if len(vd.mount.children) != 0 { + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + // We are holding a reference on vd.mount. + expectedRefs := int64(1) + if !vd.mount.umounted { + expectedRefs = 2 + } + if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + return syserror.EBUSY + } + } + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{ + eager: opts.Flags&linux.MNT_DETACH == 0, + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + return nil +} + +type umountRecursiveOptions struct { + // If eager is true, ensure that future calls to Mount.tryIncMountedRef() + // on umounted mounts fail. 
+ // + // eager is analogous to Linux's UMOUNT_SYNC. + eager bool + + // If disconnectHierarchy is true, Mounts that are umounted hierarchically + // should be disconnected from their parents. (Mounts whose parents are not + // umounted, which in most cases means the Mount passed to the initial call + // to umountRecursiveLocked, are unconditionally disconnected for + // consistency with Linux.) + // + // disconnectHierarchy is analogous to Linux's !UMOUNT_CONNECTED. + disconnectHierarchy bool +} + +// umountRecursiveLocked marks mnt and its descendants as umounted. It does not +// release mount or dentry references; instead, it appends VirtualDentries and +// Mounts on which references must be dropped to vdsToDecRef and mountsToDecRef +// respectively, and returns updated slices. (This is necessary because +// filesystem locks possibly taken by DentryImpl.DecRef() may precede +// vfs.mountMu in the lock order, and Mount.DecRef() may lock vfs.mountMu.) +// +// umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree(). +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. +func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) { + if !mnt.umounted { + mnt.umounted = true + mountsToDecRef = append(mountsToDecRef, mnt) + if parent := mnt.parent(); parent != nil && (opts.disconnectHierarchy || !parent.umounted) { + vdsToDecRef = append(vdsToDecRef, vfs.disconnectLocked(mnt)) + } + } + if opts.eager { + for { + refs := atomic.LoadInt64(&mnt.refs) + if refs < 0 { + break + } + if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs|math.MinInt64) { + break + } + } + } + for child := range mnt.children { + vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(child, opts, vdsToDecRef, mountsToDecRef) + } + return vdsToDecRef, mountsToDecRef +} + +// connectLocked makes vd the mount parent/point for mnt. It consumes +// references held by vd. +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. d.mu must be locked. mnt.parent() == nil. +func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { + mnt.storeKey(vd) + if vd.mount.children == nil { + vd.mount.children = make(map[*Mount]struct{}) + } + vd.mount.children[mnt] = struct{}{} atomic.AddUint32(&vd.dentry.mounts, 1) - mntns.mountpoints[vd.dentry] = struct{}{} + mntns.mountpoints[vd.dentry]++ + vfs.mounts.insertSeqed(mnt) vfsmpmounts, ok := vfs.mountpoints[vd.dentry] if !ok { vfsmpmounts = make(map[*Mount]struct{}) vfs.mountpoints[vd.dentry] = vfsmpmounts } vfsmpmounts[mnt] = struct{}{} - vfs.mounts.Insert(mnt) - vfs.mountMu.Unlock() - return nil +} + +// disconnectLocked makes vd have no mount parent/point and returns its old +// mount parent/point with a reference held. +// +// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a +// writer critical section. mnt.parent() != nil. 
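// All callers of umountRecursiveLocked above follow the same deferred-DecRef
// shape, since reference drops may re-enter locks that sit on either side of
// vfs.mountMu in the lock order. Condensed (illustrative):
//
//	vfsObj.mountMu.Lock()
//	vfsObj.mounts.seq.BeginWrite()
//	vds, mnts := vfsObj.umountRecursiveLocked(mnt, &umountRecursiveOptions{disconnectHierarchy: true}, nil, nil)
//	vfsObj.mounts.seq.EndWrite()
//	vfsObj.mountMu.Unlock()
//	for _, vd := range vds {
//		vd.DecRef() // may take filesystem-internal locks
//	}
//	for _, m := range mnts {
//		m.DecRef() // may take vfs.mountMu again
//	}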
+func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { + vd := mnt.loadKey() + mnt.storeKey(VirtualDentry{}) + delete(vd.mount.children, mnt) + atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 + mnt.ns.mountpoints[vd.dentry]-- + if mnt.ns.mountpoints[vd.dentry] == 0 { + delete(mnt.ns.mountpoints, vd.dentry) + } + vfs.mounts.removeSeqed(mnt) + vfsmpmounts := vfs.mountpoints[vd.dentry] + delete(vfsmpmounts, mnt) + if len(vfsmpmounts) == 0 { + delete(vfs.mountpoints, vd.dentry) + } + return vd +} + +// tryIncMountedRef increments mnt's reference count and returns true. If mnt's +// reference count is already zero, or has been eagerly umounted, +// tryIncMountedRef does nothing and returns false. +// +// tryIncMountedRef does not require that a reference is held on mnt. +func (mnt *Mount) tryIncMountedRef() bool { + for { + refs := atomic.LoadInt64(&mnt.refs) + if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted + return false + } + if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { + return true + } + } +} + +// IncRef increments mnt's reference count. +func (mnt *Mount) IncRef() { + // In general, negative values for mnt.refs are valid because the MSB is + // the eager-unmount bit. + atomic.AddInt64(&mnt.refs, 1) +} + +// DecRef decrements mnt's reference count. +func (mnt *Mount) DecRef() { + refs := atomic.AddInt64(&mnt.refs, -1) + if refs&^math.MinInt64 == 0 { // mask out MSB + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + mnt.root.DecRef() + mnt.fs.DecRef() + if vd.Ok() { + vd.DecRef() + } + } +} + +// IncRef increments mntns' reference count. +func (mntns *MountNamespace) IncRef() { + if atomic.AddInt64(&mntns.refs, 1) <= 1 { + panic("MountNamespace.IncRef() called without holding a reference") + } +} + +// DecRef decrements mntns' reference count. +func (mntns *MountNamespace) DecRef(vfs *VirtualFilesystem) { + if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 { + vfs.mountMu.Lock() + vfs.mounts.seq.BeginWrite() + vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{ + disconnectHierarchy: true, + }, nil, nil) + vfs.mounts.seq.EndWrite() + vfs.mountMu.Unlock() + for _, vd := range vdsToDecRef { + vd.DecRef() + } + for _, mnt := range mountsToDecRef { + mnt.DecRef() + } + } else if refs < 0 { + panic("MountNamespace.DecRef() called without holding a reference") + } } // getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes @@ -223,7 +478,7 @@ retryFirst: // Raced with umount. continue } - mnt.decRef() + mnt.DecRef() mnt = next d = next.root } @@ -231,12 +486,12 @@ retryFirst: } // getMountpointAt returns the mount point for the stack of Mounts including -// mnt. It takes a reference on the returned Mount and Dentry. If no such mount +// mnt. It takes a reference on the returned VirtualDentry. If no such mount // point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil). // // Preconditions: References are held on mnt and root. vfsroot is not (mnt, // mnt.root). -func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) (*Mount, *Dentry) { +func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry { // The first mount is special-cased: // // - The caller must have already checked mnt against vfsroot. 
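// The lookup paths in this file pair with the writer-side protocol via the
// standard seqcount reader loop: snapshot, validate, and retry from the top
// (dropping any references taken) if a writer intervened. Skeleton of the
// loop used by getMountpointAt below (illustrative):
//
//	retry:
//		epoch := vfs.mounts.seq.BeginRead()
//		parent, point := mnt.parent(), mnt.point()
//		if !vfs.mounts.seq.ReadOk(epoch) {
//			goto retry // raced with a mount/umount; re-snapshot
//		}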
@@ -246,21 +501,26 @@ func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) // - We don't drop the caller's reference on mnt. retryFirst: epoch := vfs.mounts.seq.BeginRead() - parent, point := mnt.loadKey() + parent, point := mnt.parent(), mnt.point() if !vfs.mounts.seq.ReadOk(epoch) { goto retryFirst } if parent == nil { - return nil, nil + return VirtualDentry{} } if !parent.tryIncMountedRef() { // Raced with umount. goto retryFirst } - if !point.tryIncRef(parent.fs) { + if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can only // happen due to a racing change to Mount.key. - parent.decRef() + parent.DecRef() + goto retryFirst + } + if !vfs.mounts.seq.ReadOk(epoch) { + point.DecRef() + parent.DecRef() goto retryFirst } mnt = parent @@ -274,7 +534,7 @@ retryFirst: } retryNotFirst: epoch := vfs.mounts.seq.BeginRead() - parent, point := mnt.loadKey() + parent, point := mnt.parent(), mnt.point() if !vfs.mounts.seq.ReadOk(epoch) { goto retryNotFirst } @@ -285,59 +545,23 @@ retryFirst: // Raced with umount. goto retryNotFirst } - if !point.tryIncRef(parent.fs) { + if !point.TryIncRef() { // Since Mount holds a reference on Mount.key.point, this can // only happen due to a racing change to Mount.key. - parent.decRef() + parent.DecRef() goto retryNotFirst } if !vfs.mounts.seq.ReadOk(epoch) { - point.decRef(parent.fs) - parent.decRef() + point.DecRef() + parent.DecRef() goto retryNotFirst } - d.decRef(mnt.fs) - mnt.decRef() + d.DecRef() + mnt.DecRef() mnt = parent d = point } - return mnt, d -} - -// tryIncMountedRef increments mnt's reference count and returns true. If mnt's -// reference count is already zero, or has been eagerly unmounted, -// tryIncMountedRef does nothing and returns false. -// -// tryIncMountedRef does not require that a reference is held on mnt. -func (mnt *Mount) tryIncMountedRef() bool { - for { - refs := atomic.LoadInt64(&mnt.refs) - if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted - return false - } - if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { - return true - } - } -} - -func (mnt *Mount) incRef() { - // In general, negative values for mnt.refs are valid because the MSB is - // the eager-unmount bit. - atomic.AddInt64(&mnt.refs, 1) -} - -func (mnt *Mount) decRef() { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - parent, point := mnt.loadKey() - if point != nil { - point.decRef(parent.fs) - parent.decRef() - } - mnt.root.decRef(mnt.fs) - mnt.fs.decRef() - } + return VirtualDentry{mnt, d} } // CheckBeginWrite increments the counter of in-progress write operations on @@ -360,7 +584,7 @@ func (mnt *Mount) EndWrite() { atomic.AddInt64(&mnt.writers, -1) } -// Preconditions: VirtualFilesystem.mountMu must be locked for writing. +// Preconditions: VirtualFilesystem.mountMu must be locked. func (mnt *Mount) setReadOnlyLocked(ro bool) error { if oldRO := atomic.LoadInt64(&mnt.writers) < 0; oldRO == ro { return nil @@ -383,22 +607,6 @@ func (mnt *Mount) Filesystem() *Filesystem { return mnt.fs } -// IncRef increments mntns' reference count. -func (mntns *MountNamespace) IncRef() { - if atomic.AddInt64(&mntns.refs, 1) <= 1 { - panic("MountNamespace.IncRef() called without holding a reference") - } -} - -// DecRef decrements mntns' reference count. 
-func (mntns *MountNamespace) DecRef() { - if refs := atomic.AddInt64(&mntns.refs, 0); refs == 0 { - // TODO: unmount mntns.root - } else if refs < 0 { - panic("MountNamespace.DecRef() called without holding a reference") - } -} - // Root returns mntns' root. A reference is taken on the returned // VirtualDentry. func (mntns *MountNamespace) Root() VirtualDentry { diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index f394d7483..adff0b94b 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -37,7 +37,7 @@ func TestMountTableInsertLookup(t *testing.T) { mt.Init() mount := &Mount{} - mount.storeKey(&Mount{}, &Dentry{}) + mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) mt.Insert(mount) if m := mt.Lookup(mount.parent(), mount.point()); m != mount { @@ -78,18 +78,10 @@ const enableComparativeBenchmarks = false func newBenchMount() *Mount { mount := &Mount{} - mount.storeKey(&Mount{}, &Dentry{}) + mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) return mount } -func vdkey(mnt *Mount) VirtualDentry { - parent, point := mnt.loadKey() - return VirtualDentry{ - mount: parent, - dentry: point, - } -} - func BenchmarkMountTableParallelLookup(b *testing.B) { for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 { for _, numMounts := range benchNumMounts { @@ -101,7 +93,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { for i := 0; i < numMounts; i++ { mount := newBenchMount() mt.Insert(mount) - keys = append(keys, vdkey(mount)) + keys = append(keys, mount.loadKey()) } var ready sync.WaitGroup @@ -153,7 +145,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := vdkey(mount) + key := mount.loadKey() ms[key] = mount keys = append(keys, key) } @@ -208,7 +200,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := vdkey(mount) + key := mount.loadKey() ms.Store(key, mount) keys = append(keys, key) } @@ -290,7 +282,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -325,7 +317,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { var ms sync.Map for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -379,7 +371,7 @@ func BenchmarkMountMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } } @@ -399,7 +391,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } } @@ -432,13 +424,13 @@ func BenchmarkMountMapRemove(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := range mounts { mount := mounts[i] - ms[vdkey(mount)] = mount + ms[mount.loadKey()] = mount } b.ResetTimer() for i := range mounts { mount := mounts[i] - delete(ms, vdkey(mount)) + delete(ms, mount.loadKey()) } } @@ -454,12 +446,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) { var ms sync.Map for i := range 
mounts { mount := mounts[i] - ms.Store(vdkey(mount), mount) + ms.Store(mount.loadKey(), mount) } b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Delete(vdkey(mount)) + ms.Delete(mount.loadKey()) } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b0511aa40..ab13fa461 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/third_party/gvsync" + "gvisor.dev/gvisor/pkg/syncutil" ) // mountKey represents the location at which a Mount is mounted. It is @@ -38,16 +38,6 @@ type mountKey struct { point unsafe.Pointer // *Dentry } -// Invariant: mnt.key's fields are nil. parent and point are non-nil. -func (mnt *Mount) storeKey(parent *Mount, point *Dentry) { - atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(parent)) - atomic.StorePointer(&mnt.key.point, unsafe.Pointer(point)) -} - -func (mnt *Mount) loadKey() (*Mount, *Dentry) { - return (*Mount)(atomic.LoadPointer(&mnt.key.parent)), (*Dentry)(atomic.LoadPointer(&mnt.key.point)) -} - func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,6 +46,19 @@ func (mnt *Mount) point() *Dentry { return (*Dentry)(atomic.LoadPointer(&mnt.key.point)) } +func (mnt *Mount) loadKey() VirtualDentry { + return VirtualDentry{ + mount: mnt.parent(), + dentry: mnt.point(), + } +} + +// Invariant: mnt.key.parent == nil. vd.Ok(). +func (mnt *Mount) storeKey(vd VirtualDentry) { + atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) + atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) +} + // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). @@ -72,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq gvsync.SeqCount + seq syncutil.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of @@ -201,9 +204,19 @@ loop: // Insert inserts the given mount into mt. // -// Preconditions: There are no concurrent mutators of mt. mt must not already -// contain a Mount with the same mount point and parent. +// Preconditions: mt must not already contain a Mount with the same mount point +// and parent. func (mt *mountTable) Insert(mount *Mount) { + mt.seq.BeginWrite() + mt.insertSeqed(mount) + mt.seq.EndWrite() +} + +// insertSeqed inserts the given mount into mt. +// +// Preconditions: mt.seq must be in a writer critical section. mt must not +// already contain a Mount with the same mount point and parent. +func (mt *mountTable) insertSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) // We're under the maximum load factor if: @@ -215,10 +228,8 @@ func (mt *mountTable) Insert(mount *Mount) { tcap := uintptr(1) << order if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) { // Atomically insert the new element into the table. 
- mt.seq.BeginWrite() atomic.AddUint64(&mt.size, mtSizeLenOne) mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash) - mt.seq.EndWrite() return } @@ -241,8 +252,6 @@ func (mt *mountTable) Insert(mount *Mount) { for { oldSlot := (*mountSlot)(oldCur) if oldSlot.value != nil { - // Don't need to lock mt.seq yet since newSlots isn't visible - // to readers. mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash) } if oldCur == oldLast { @@ -252,11 +261,9 @@ func (mt *mountTable) Insert(mount *Mount) { } // Insert the new element into the new table. mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash) - // Atomically switch to the new table. - mt.seq.BeginWrite() + // Switch to the new table. atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne) atomic.StorePointer(&mt.slots, newSlots) - mt.seq.EndWrite() } // Preconditions: There are no concurrent mutators of the table (slots, cap). @@ -294,9 +301,18 @@ func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, has // Remove removes the given mount from mt. // -// Preconditions: There are no concurrent mutators of mt. mt must contain -// mount. +// Preconditions: mt must contain mount. func (mt *mountTable) Remove(mount *Mount) { + mt.seq.BeginWrite() + mt.removeSeqed(mount) + mt.seq.EndWrite() +} + +// removeSeqed removes the given mount from mt. +// +// Preconditions: mt.seq must be in a writer critical section. mt must contain +// mount. +func (mt *mountTable) removeSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 @@ -311,7 +327,6 @@ func (mt *mountTable) Remove(mount *Mount) { // backward until we either find an empty slot, or an element that // is already in its first-probed slot. (This is backward shift // deletion.) - mt.seq.BeginWrite() for { nextOff := (off + mountSlotBytes) & offmask nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff)) @@ -330,7 +345,6 @@ func (mt *mountTable) Remove(mount *Mount) { } atomic.StorePointer(&slot.value, nil) atomic.AddUint64(&mt.size, mtSizeLenNegOne) - mt.seq.EndWrite() return } if checkInvariants && slotValue == nil { diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 3aa73d911..97ee4a446 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -46,6 +46,12 @@ type MknodOptions struct { DevMinor uint32 } +// MountOptions contains options to VirtualFilesystem.MountAt(). +type MountOptions struct { + // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). + GetFilesystemOptions GetFilesystemOptions +} + // OpenOptions contains options to VirtualFilesystem.OpenAt() and // FilesystemImpl.OpenAt(). type OpenOptions struct { @@ -95,6 +101,20 @@ type SetStatOptions struct { Stat linux.Statx } +// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), +// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and +// FileDescriptionImpl.Setxattr(). +type SetxattrOptions struct { + // Name is the name of the extended attribute being mutated. + Name string + + // Value is the extended attribute's new value. + Value string + + // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2). + Flags uint32 +} + // StatOptions contains options to VirtualFilesystem.StatAt(), // FilesystemImpl.StatAt(), FileDescription.Stat(), and // FileDescriptionImpl.Stat(). 
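// Usage sketch for SetxattrOptions above (illustrative; XATTR_CREATE
// semantics as in setxattr(2), assuming the flag constant comes from the new
// abi/linux xattr definitions):
//
//	err := fd.Setxattr(ctx, vfs.SetxattrOptions{
//		Name:  "user.comment",
//		Value: "hello",
//		Flags: linux.XATTR_CREATE, // fail if the attribute already exists
//	})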
@@ -114,6 +134,12 @@ type StatOptions struct { Sync uint32 } +// UmountOptions contains options to VirtualFilesystem.UmountAt(). +type UmountOptions struct { + // Flags contains flags as specified for umount2(2). + Flags uint32 +} + // WriteOptions contains options to FileDescription.PWrite(), // FileDescriptionImpl.PWrite(), FileDescription.Write(), and // FileDescriptionImpl.Write(). diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go new file mode 100644 index 000000000..8e155654f --- /dev/null +++ b/pkg/sentry/vfs/pathname.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/syserror" +) + +var fspathBuilderPool = sync.Pool{ + New: func() interface{} { + return &fspath.Builder{} + }, +} + +func getFSPathBuilder() *fspath.Builder { + return fspathBuilderPool.Get().(*fspath.Builder) +} + +func putFSPathBuilder(b *fspath.Builder) { + // No methods can be called on b after b.String(), so reset it to its zero + // value (as returned by fspathBuilderPool.New) instead. + *b = fspath.Builder{} + fspathBuilderPool.Put(b) +} + +// PathnameWithDeleted returns an absolute pathname to vd, consistent with +// Linux's d_path(). In particular, if vd.Dentry() has been disowned, +// PathnameWithDeleted appends " (deleted)" to the returned pathname. +func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { + b := getFSPathBuilder() + defer putFSPathBuilder(b) + haveRef := false + defer func() { + if haveRef { + vd.DecRef() + } + }() + + origD := vd.dentry +loop: + for { + err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b) + switch err.(type) { + case nil: + if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { + // GenericPrependPath() will have returned + // PrependPathAtVFSRootError in this case since it checks + // against vfsroot before mnt.root, but other implementations + // of FilesystemImpl.PrependPath() may return nil instead. + break loop + } + nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + if !nextVD.Ok() { + break loop + } + if haveRef { + vd.DecRef() + } + vd = nextVD + haveRef = true + // continue loop + case PrependPathSyntheticError: + // Skip prepending "/" and appending " (deleted)". + return b.String(), nil + case PrependPathAtVFSRootError, PrependPathAtNonMountRootError: + break loop + default: + return "", err + } + } + b.PrependByte('/') + if origD.IsDisowned() { + b.AppendString(" (deleted)") + } + return b.String(), nil +} + +// PathnameForGetcwd returns an absolute pathname to vd, consistent with +// Linux's sys_getcwd(). 
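+// Unlike PathnameWithDeleted, it fails with ENOENT if vd has been disowned,
+// and prefixes "(unreachable)" if vd cannot be reached from vfsroot.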
+func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) { + if vd.dentry.IsDisowned() { + return "", syserror.ENOENT + } + + b := getFSPathBuilder() + defer putFSPathBuilder(b) + haveRef := false + defer func() { + if haveRef { + vd.DecRef() + } + }() + unreachable := false +loop: + for { + err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b) + switch err.(type) { + case nil: + if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry { + break loop + } + nextVD := vfs.getMountpointAt(vd.mount, vfsroot) + if !nextVD.Ok() { + unreachable = true + break loop + } + if haveRef { + vd.DecRef() + } + vd = nextVD + haveRef = true + case PrependPathAtVFSRootError: + break loop + case PrependPathAtNonMountRootError, PrependPathSyntheticError: + unreachable = true + break loop + default: + return "", err + } + } + b.PrependByte('/') + if unreachable { + b.PrependString("(unreachable)") + } + return b.String(), nil +} + +// As of this writing, we do not have equivalents to: +// +// - d_absolute_path(), which returns EINVAL if (effectively) any call to +// FilesystemImpl.PrependPath() would return PrependPathAtNonMountRootError. +// +// - dentry_path(), which does not walk up mounts (and only returns the path +// relative to Filesystem root), but also appends "//deleted" for disowned +// Dentries. +// +// These should be added as necessary. diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index f8e74355c..f1edb0680 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -119,3 +119,65 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { return false } } + +// CheckSetStat checks that creds has permission to change the metadata of a +// file with the given permissions, UID, and GID as specified by stat, subject +// to the rules of Linux's fs/attr.c:setattr_prepare(). +func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error { + if stat.Mask&linux.STATX_MODE != 0 { + if !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + // TODO(b/30815691): "If the calling process is not privileged (Linux: + // does not have the CAP_FSETID capability), and the group of the file + // does not match the effective group ID of the process or one of its + // supplementary group IDs, the S_ISGID bit will be turned off, but + // this will not cause an error to be returned." - chmod(2) + } + if stat.Mask&linux.STATX_UID != 0 { + if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&linux.STATX_GID != 0 { + if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) || + HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) { + return syserror.EPERM + } + } + if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { + if !CanActAsOwner(creds, kuid) { + if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) || + (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) { + return syserror.EPERM + } + // isDir is irrelevant in the following call to + // GenericCheckPermissions since ats == MayWrite means that + // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE + // applies, regardless of isDir. 
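+			// Consequently, a non-owner may set all requested timestamps
+			// to the current time given only write permission, as with
+			// utimensat(2).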
+ if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil { + return err + } + } + } + return nil +} + +// CanActAsOwner returns true if creds can act as the owner of a file with the +// given owning UID, consistent with Linux's +// fs/inode.c:inode_owner_or_capable(). +func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool { + if creds.EffectiveKUID == kuid { + return true + } + return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok() +} + +// HasCapabilityOnFile returns true if creds has the given capability with +// respect to a file with the given owning UID and GID, consistent with Linux's +// kernel/capability.c:capable_wrt_inode_uidgid(). +func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool { + return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok() +} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 8d05c8583..d580fd39e 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -85,11 +85,11 @@ func init() { // so error "constants" are really mutable vars, necessitating somewhat // expensive interface object comparisons. -type resolveMountRootError struct{} +type resolveMountRootOrJumpError struct{} // Error implements error.Error. -func (resolveMountRootError) Error() string { - return "resolving mount root" +func (resolveMountRootOrJumpError) Error() string { + return "resolving mount root or jump" } type resolveMountPointError struct{} @@ -149,20 +149,20 @@ func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) { func (rp *ResolvingPath) decRefStartAndMount() { if rp.flags&rpflagsHaveStartRef != 0 { - rp.start.decRef(rp.mount.fs) + rp.start.DecRef() } if rp.flags&rpflagsHaveMountRef != 0 { - rp.mount.decRef() + rp.mount.DecRef() } } func (rp *ResolvingPath) releaseErrorState() { if rp.nextStart != nil { - rp.nextStart.decRef(rp.nextMount.fs) + rp.nextStart.DecRef() rp.nextStart = nil } if rp.nextMount != nil { - rp.nextMount.decRef() + rp.nextMount.DecRef() rp.nextMount = nil } } @@ -269,12 +269,12 @@ func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) { parent = d } else if d == rp.mount.root { // At mount root ... - mnt, mntpt := rp.vfs.getMountpointAt(rp.mount, rp.root) - if mnt != nil { + vd := rp.vfs.getMountpointAt(rp.mount, rp.root) + if vd.Ok() { // ... of non-root mount. - rp.nextMount = mnt - rp.nextStart = mntpt - return nil, resolveMountRootError{} + rp.nextMount = vd.mount + rp.nextStart = vd.dentry + return nil, resolveMountRootOrJumpError{} } // ... of root mount. parent = d @@ -385,11 +385,32 @@ func (rp *ResolvingPath) relpathPrepend(path fspath.Path) { } } +// HandleJump is called when the current path component is a "magic" link to +// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem +// method should continue path traversal, HandleJump updates the path +// component stream to reflect the magic link target and returns nil. Otherwise +// it returns a non-nil error. +// +// Preconditions: !rp.Done(). +func (rp *ResolvingPath) HandleJump(target VirtualDentry) error { + if rp.symlinks >= linux.MaxSymlinkTraversals { + return syserror.ELOOP + } + rp.symlinks++ + // Consume the path component that represented the magic link.
+ rp.Advance() + // Unconditionally return a resolveMountRootOrJumpError, even if the Mount + // isn't changing, to force restarting at the new Dentry. + target.IncRef() + rp.nextMount = target.mount + rp.nextStart = target.dentry + return resolveMountRootOrJumpError{} +} + func (rp *ResolvingPath) handleError(err error) bool { switch err.(type) { - case resolveMountRootError: - // Switch to the new Mount. We hold references on the Mount and Dentry - // (from VFS.getMountpointAt()). + case resolveMountRootOrJumpError: + // Switch to the new Mount. We hold references on the Mount and Dentry. rp.decRefStartAndMount() rp.mount = rp.nextMount rp.start = rp.nextStart @@ -407,9 +428,8 @@ func (rp *ResolvingPath) handleError(err error) bool { return true case resolveMountPointError: - // Switch to the new Mount. We hold a reference on the Mount (from - // VFS.getMountAt()), but borrow the reference on the mount root from - // the Mount. + // Switch to the new Mount. We hold a reference on the Mount, but + // borrow the reference on the mount root from the Mount. rp.decRefStartAndMount() rp.mount = rp.nextMount rp.start = rp.nextMount.root diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go deleted file mode 100644 index 23f2b9e08..000000000 --- a/pkg/sentry/vfs/syscalls.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vfs - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/syserror" -) - -// PathOperation specifies the path operated on by a VFS method. -// -// PathOperation is passed to VFS methods by pointer to reduce memory copying: -// it's somewhat large and should never escape. (Options structs are passed by -// pointer to VFS and FileDescription methods for the same reason.) -type PathOperation struct { - // Root is the VFS root. References on Root are borrowed from the provider - // of the PathOperation. - // - // Invariants: Root.Ok(). - Root VirtualDentry - - // Start is the starting point for the path traversal. References on Start - // are borrowed from the provider of the PathOperation (i.e. the caller of - // the VFS method to which the PathOperation was passed). - // - // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. - Start VirtualDentry - - // Path is the pathname traversed by this operation. - Pathname string - - // If FollowFinalSymlink is true, and the Dentry traversed by the final - // path component represents a symbolic link, the symbolic link should be - // followed. - FollowFinalSymlink bool -} - -// GetDentryAt returns a VirtualDentry representing the given path, at which a -// file must exist. A reference is taken on the returned VirtualDentry. 
-func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return VirtualDentry{}, err - } - for { - d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) - if err == nil { - vd := VirtualDentry{ - mount: rp.mount, - dentry: d, - } - rp.mount.incRef() - vfs.putResolvingPath(rp) - return vd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return VirtualDentry{}, err - } - } -} - -// MkdirAt creates a directory at the given path. -func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { - // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is - // also honored." - mkdir(2) - opts.Mode &= 01777 - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return err - } - for { - err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return err - } - } -} - -// OpenAt returns a FileDescription providing access to the file at the given -// path. A reference is taken on the returned FileDescription. -func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { - // Remove: - // - // - O_LARGEFILE, which we always report in FileDescription status flags - // since only 64-bit architectures are supported at this time. - // - // - O_CLOEXEC, which affects file descriptors and therefore must be - // handled outside of VFS. - // - // - Unknown flags. - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE - // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. - if opts.Flags&linux.O_SYNC != 0 { - opts.Flags |= linux.O_DSYNC - } - // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified - // with O_DIRECTORY and a writable access mode (to ensure that it fails on - // filesystem implementations that do not support it). - if opts.Flags&linux.O_TMPFILE != 0 { - if opts.Flags&linux.O_DIRECTORY == 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_CREAT != 0 { - return nil, syserror.EINVAL - } - if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { - return nil, syserror.EINVAL - } - } - // O_PATH causes most other flags to be ignored. - if opts.Flags&linux.O_PATH != 0 { - opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH - } - // "On Linux, the following bits are also honored in mode: [S_ISUID, - // S_ISGID, S_ISVTX]" - open(2) - opts.Mode &= 07777 - - if opts.Flags&linux.O_NOFOLLOW != 0 { - pop.FollowFinalSymlink = false - } - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return nil, err - } - if opts.Flags&linux.O_DIRECTORY != 0 { - rp.mustBeDir = true - rp.mustBeDirOrig = true - } - for { - fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return fd, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return nil, err - } - } -} - -// StatAt returns metadata for the file at the given path. 
-func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { - rp, err := vfs.getResolvingPath(creds, pop) - if err != nil { - return linux.Statx{}, err - } - for { - stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) - if err == nil { - vfs.putResolvingPath(rp) - return stat, nil - } - if !rp.handleError(err) { - vfs.putResolvingPath(rp) - return linux.Statx{}, err - } - } -} - -// StatusFlags returns file description status flags. -func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) { - flags, err := fd.impl.StatusFlags(ctx) - flags |= linux.O_LARGEFILE - return flags, err -} - -// SetStatusFlags sets file description status flags. -func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error { - return fd.impl.SetStatusFlags(ctx, flags) -} - -// TODO: -// -// - VFS.SyncAllFilesystems() for sync(2) -// -// - Something for syncfs(2) -// -// - VFS.LinkAt() -// -// - VFS.MknodAt() -// -// - VFS.ReadlinkAt() -// -// - VFS.RenameAt() -// -// - VFS.RmdirAt() -// -// - VFS.SetStatAt() -// -// - VFS.StatFSAt() -// -// - VFS.SymlinkAt() -// -// - VFS.UnlinkAt() -// -// - FileDescription.(almost everything) diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go index 70b192ece..d94117bce 100644 --- a/pkg/sentry/vfs/testutil.go +++ b/pkg/sentry/vfs/testutil.go @@ -15,7 +15,10 @@ package vfs import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" @@ -33,10 +36,10 @@ type FDTestFilesystem struct { vfsfs Filesystem } -// NewFilesystem implements FilesystemType.NewFilesystem. -func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) { +// GetFilesystem implements FilesystemType.GetFilesystem. +func (fstype FDTestFilesystemType) GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) { var fs FDTestFilesystem - fs.vfsfs.Init(&fs) + fs.vfsfs.Init(vfsObj, &fs) return &fs.vfsfs, fs.NewDentry(), nil } @@ -114,6 +117,32 @@ func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) err return syserror.EPERM } +// ListxattrAt implements FilesystemImpl.ListxattrAt. +func (fs *FDTestFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) { + return nil, syserror.EPERM +} + +// GetxattrAt implements FilesystemImpl.GetxattrAt. +func (fs *FDTestFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) { + return "", syserror.EPERM +} + +// SetxattrAt implements FilesystemImpl.SetxattrAt. +func (fs *FDTestFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { + return syserror.EPERM +} + +// RemovexattrAt implements FilesystemImpl.RemovexattrAt. +func (fs *FDTestFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { + return syserror.EPERM +} + +// PrependPath implements FilesystemImpl.PrependPath. 
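+// It prepends a single synthetic component naming the dentry by address, so
+// paths are stable while the dentry lives but never reach a real filesystem
+// root; hence the PrependPathSyntheticError return below.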
+func (fs *FDTestFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error { + b.PrependComponent(fmt.Sprintf("vfs.fdTestDentry:%p", vd.dentry.impl.(*fdTestDentry))) + return PrependPathSyntheticError{} +} + type fdTestDentry struct { vfsd Dentry } @@ -126,14 +155,14 @@ func (fs *FDTestFilesystem) NewDentry() *Dentry { } // IncRef implements DentryImpl.IncRef. -func (d *fdTestDentry) IncRef(vfsfs *Filesystem) { +func (d *fdTestDentry) IncRef() { } // TryIncRef implements DentryImpl.TryIncRef. -func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool { +func (d *fdTestDentry) TryIncRef() bool { return true } // DecRef implements DentryImpl.DecRef. -func (d *fdTestDentry) DecRef(vfsfs *Filesystem) { +func (d *fdTestDentry) DecRef() { } diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 4a8a69540..e60898d7c 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -16,13 +16,24 @@ // // Lock order: // -// Filesystem implementation locks +// FilesystemImpl/FileDescriptionImpl locks // VirtualFilesystem.mountMu +// Dentry.mu +// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry +// VirtualFilesystem.filesystemsMu // VirtualFilesystem.fsTypesMu +// +// Locking Dentry.mu in multiple Dentries requires holding +// VirtualFilesystem.mountMu. package vfs import ( "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" ) // A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts. @@ -33,7 +44,7 @@ type VirtualFilesystem struct { // mountMu serializes mount mutations. // // mountMu is analogous to Linux's namespace_sem. - mountMu sync.RWMutex + mountMu sync.Mutex // mounts maps (mount parent, mount point) pairs to mounts. (Since mounts // are uniquely namespaced, including mount parent in the key correctly @@ -52,7 +63,7 @@ type VirtualFilesystem struct { // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. // - // mountpoints is used to find mounts that must be unmounted due to + // mountpoints is used to find mounts that must be umounted due to // removal of a mount point Dentry from another mount namespace. ("A file // or directory that is a mount point in one namespace that is not a mount // point in another namespace, may be renamed, unlinked, or removed @@ -62,6 +73,11 @@ type VirtualFilesystem struct { // mountpoints is analogous to Linux's mountpoint_hashtable. mountpoints map[*Dentry]map[*Mount]struct{} + // filesystems contains all Filesystems. filesystems is protected by + // filesystemsMu. + filesystemsMu sync.Mutex + filesystems map[*Filesystem]struct{} + // fsTypes contains all FilesystemTypes that are usable in the // VirtualFilesystem. fsTypes is protected by fsTypesMu. fsTypesMu sync.RWMutex @@ -72,12 +88,466 @@ type VirtualFilesystem struct { func New() *VirtualFilesystem { vfs := &VirtualFilesystem{ mountpoints: make(map[*Dentry]map[*Mount]struct{}), + filesystems: make(map[*Filesystem]struct{}), fsTypes: make(map[string]FilesystemType), } vfs.mounts.Init() return vfs } +// PathOperation specifies the path operated on by a VFS method. +// +// PathOperation is passed to VFS methods by pointer to reduce memory copying: +// it's somewhat large and should never escape. (Options structs are passed by +// pointer to VFS and FileDescription methods for the same reason.) 
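+//
+// A sketch of a typical lookup; root and wd are VirtualDentries held by the
+// caller, and the pathname is illustrative:
+//
+//	pop := PathOperation{
+//		Root:               root,
+//		Start:              wd,
+//		Pathname:           "/etc/passwd",
+//		FollowFinalSymlink: true,
+//	}
+//	vd, err := vfs.GetDentryAt(ctx, creds, &pop, &GetDentryOptions{})
+//	// On success, the caller owns a reference on vd and must DecRef() it.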
+type PathOperation struct { + // Root is the VFS root. References on Root are borrowed from the provider + // of the PathOperation. + // + // Invariants: Root.Ok(). + Root VirtualDentry + + // Start is the starting point for the path traversal. References on Start + // are borrowed from the provider of the PathOperation (i.e. the caller of + // the VFS method to which the PathOperation was passed). + // + // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root. + Start VirtualDentry + + // Pathname is the pathname traversed by this operation. + Pathname string + + // If FollowFinalSymlink is true, and the Dentry traversed by the final + // path component represents a symbolic link, the symbolic link should be + // followed. + FollowFinalSymlink bool +} + +// GetDentryAt returns a VirtualDentry representing the given path, at which a +// file must exist. A reference is taken on the returned VirtualDentry. +func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return VirtualDentry{}, err + } + for { + d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts) + if err == nil { + vd := VirtualDentry{ + mount: rp.mount, + dentry: d, + } + rp.mount.IncRef() + vfs.putResolvingPath(rp) + return vd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return VirtualDentry{}, err + } + } +} + +// LinkAt creates a hard link at newpop representing the existing file at +// oldpop. +func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// MkdirAt creates a directory at the given path. +func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { + // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is + // also honored." - mkdir(2) + opts.Mode &= 0777 | linux.S_ISVTX + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// MknodAt creates a file of the given mode at the given path. It returns an +// error from the syserror package. +func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { + vfs.putResolvingPath(rp) + return nil + } + // Handle mount traversals. + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// OpenAt returns a FileDescription providing access to the file at the given +// path. A reference is taken on the returned FileDescription.
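+//
+// For example, a read-only open of an existing file might look like this
+// (pop as in the PathOperation sketch above; error handling abbreviated):
+//
+//	fd, err := vfs.OpenAt(ctx, creds, &pop, &OpenOptions{Flags: linux.O_RDONLY})
+//	if err != nil {
+//		return err
+//	}
+//	defer fd.DecRef()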
+func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { + // Remove: + // + // - O_LARGEFILE, which we always report in FileDescription status flags + // since only 64-bit architectures are supported at this time. + // + // - O_CLOEXEC, which affects file descriptors and therefore must be + // handled outside of VFS. + // + // - Unknown flags. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE + // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC. + if opts.Flags&linux.O_SYNC != 0 { + opts.Flags |= linux.O_DSYNC + } + // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified + // with O_DIRECTORY and a writable access mode (to ensure that it fails on + // filesystem implementations that do not support it). + if opts.Flags&linux.O_TMPFILE != 0 { + if opts.Flags&linux.O_DIRECTORY == 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_CREAT != 0 { + return nil, syserror.EINVAL + } + if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY { + return nil, syserror.EINVAL + } + } + // O_PATH causes most other flags to be ignored. + if opts.Flags&linux.O_PATH != 0 { + opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH + } + // "On Linux, the following bits are also honored in mode: [S_ISUID, + // S_ISGID, S_ISVTX]" - open(2) + opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX + + if opts.Flags&linux.O_NOFOLLOW != 0 { + pop.FollowFinalSymlink = false + } + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + if opts.Flags&linux.O_DIRECTORY != 0 { + rp.mustBeDir = true + rp.mustBeDirOrig = true + } + for { + fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return fd, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// ReadlinkAt returns the target of the symbolic link at the given path. +func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return target, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// RenameAt renames the file at oldpop to newpop. +func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error { + oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{}) + if err != nil { + return err + } + rp, err := vfs.getResolvingPath(creds, newpop) + if err != nil { + oldVD.DecRef() + return err + } + for { + err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts) + if err == nil { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + oldVD.DecRef() + vfs.putResolvingPath(rp) + return err + } + } +} + +// RmdirAt removes the directory at the given path. 
+func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RmdirAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SetStatAt changes metadata for the file at the given path. +func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// StatAt returns metadata for the file at the given path. +func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statx{}, err + } + for { + stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return stat, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statx{}, err + } + } +} + +// StatFSAt returns metadata for the filesystem containing the file at the +// given path. +func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return linux.Statfs{}, err + } + for { + statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return statfs, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return linux.Statfs{}, err + } + } +} + +// SymlinkAt creates a symbolic link at the given path with the given target. +func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// UnlinkAt deletes the non-directory file at the given path. +func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.UnlinkAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// ListxattrAt returns all extended attribute names for the file at the given +// path. 
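+// Filesystems with no extended attribute support may return ENOTSUP, which
+// the loop below translates to an empty list, mirroring Linux's
+// fs/xattr.c:vfs_listxattr().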
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil, err + } + for { + names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp) + if err == nil { + vfs.putResolvingPath(rp) + return names, nil + } + if err == syserror.ENOTSUP { + // Linux doesn't actually return ENOTSUP in this case; instead, + // fs/xattr.c:vfs_listxattr() falls back to allowing the security + // subsystem to return security extended attributes, which by + // default don't exist. + vfs.putResolvingPath(rp) + return nil, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return nil, err + } + } +} + +// GetxattrAt returns the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return "", err + } + for { + val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return val, nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return "", err + } + } +} + +// SetxattrAt changes the value associated with the given extended attribute +// for the file at the given path. +func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// RemovexattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return err + } + for { + err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + if err == nil { + vfs.putResolvingPath(rp) + return nil + } + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + +// SyncAllFilesystems has the semantics of Linux's sync(2). +func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + fss := make(map[*Filesystem]struct{}) + vfs.filesystemsMu.Lock() + for fs := range vfs.filesystems { + if !fs.TryIncRef() { + continue + } + fss[fs] = struct{}{} + } + vfs.filesystemsMu.Unlock() + var retErr error + for fs := range fss { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef() + } + return retErr +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). @@ -111,15 +581,15 @@ func (vd VirtualDentry) Ok() bool { // IncRef increments the reference counts on the Mount and Dentry represented // by vd. func (vd VirtualDentry) IncRef() { - vd.mount.incRef() - vd.dentry.incRef(vd.mount.fs) + vd.mount.IncRef() + vd.dentry.IncRef() } // DecRef decrements the reference counts on the Mount and Dentry represented // by vd. 
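+// Each DecRef must pair with a prior IncRef, or with a reference obtained
+// from a method documented as returning one (e.g. GetDentryAt).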
func (vd VirtualDentry) DecRef() { - vd.dentry.decRef(vd.mount.fs) - vd.mount.decRef() + vd.dentry.DecRef() + vd.mount.DecRef() } // Mount returns the Mount associated with vd. It does not take a reference on diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 145102c0d..5e4611333 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -42,8 +42,35 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) -// DefaultTimeout is a resonable timeout value for most applications. -const DefaultTimeout = 3 * time.Minute +// Opts configures the watchdog. +type Opts struct { + // TaskTimeout is the amount of time to allow a task to execute the + // same syscall without blocking before it's declared stuck. + TaskTimeout time.Duration + + // TaskTimeoutAction indicates what action to take when a stuck task + // is detected. + TaskTimeoutAction Action + + // StartupTimeout is the amount of time to allow between watchdog + // creation and calling watchdog.Start. + StartupTimeout time.Duration + + // StartupTimeoutAction indicates what action to take when + // watchdog.Start is not called within the timeout. + StartupTimeoutAction Action +} + +// DefaultOpts is a default set of options for the watchdog. +var DefaultOpts = Opts{ + // Task timeout. + TaskTimeout: 3 * time.Minute, + TaskTimeoutAction: LogWarning, + + // Startup timeout. + StartupTimeout: 30 * time.Second, + StartupTimeoutAction: LogWarning, +} // descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period // is discounted from task's last update time. It's set high enough that small scheduling delays won't @@ -61,6 +88,7 @@ type Action int const ( // LogWarning logs warning message followed by stack trace. LogWarning Action = iota + // Panic will do the same logging as LogWarning and panic(). Panic ) @@ -80,17 +108,13 @@ func (a Action) String() string { // Watchdog is the main watchdog class. It controls a goroutine that periodically // analyses all tasks and reports if any of them appear to be stuck. type Watchdog struct { + // Configuration options are embedded. + Opts + // period indicates how often to check all tasks. It's calculated based on - // 'taskTimeout'. + // opts.TaskTimeout. period time.Duration - // taskTimeout is the amount of time to allow a task to execute the same syscall - // without blocking before it's declared stuck. - taskTimeout time.Duration - - // timeoutAction indicates what action to take when a stuck tasks is detected. - timeoutAction Action - // k is where the tasks come from. k *kernel.Kernel @@ -113,8 +137,12 @@ type Watchdog struct { // mu protects the fields below. mu sync.Mutex - // started is true if the watchdog has been started before. - started bool + // running is true if the watchdog is running. + running bool + + // startCalled is true if Start has ever been called. It remains true + // even if Stop is called. + startCalled bool } type offender struct { @@ -122,58 +150,81 @@ } // New creates a new watchdog. -func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog { - // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much.
- period := taskTimeout / 4 - return &Watchdog{ - k: k, - period: period, - taskTimeout: taskTimeout, - timeoutAction: a, - offenders: make(map[*kernel.Task]*offender), - stop: make(chan struct{}), - done: make(chan struct{}), +func New(k *kernel.Kernel, opts Opts) *Watchdog { + // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much. + period := opts.TaskTimeout / 4 + w := &Watchdog{ + Opts: opts, + k: k, + period: period, + offenders: make(map[*kernel.Task]*offender), + stop: make(chan struct{}), + done: make(chan struct{}), + } + + // Handle StartupTimeout if it exists. + if w.StartupTimeout > 0 { + log.Infof("Watchdog waiting %v for startup", w.StartupTimeout) + go w.waitForStart() // S/R-SAFE: watchdog is stopped during save and restarted after restore. } + + return w } // Start starts the watchdog. func (w *Watchdog) Start() { - if w.taskTimeout == 0 { - log.Infof("Watchdog disabled") - return - } - w.mu.Lock() defer w.mu.Unlock() - if w.started { + w.startCalled = true + + if w.running { return } + if w.TaskTimeout == 0 { + log.Infof("Watchdog task timeout disabled") + return + } w.lastRun = w.k.MonotonicClock().Now() - log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction) + log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction) go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore. - w.started = true + w.running = true } // Stop requests the watchdog to stop and wait for it. func (w *Watchdog) Stop() { - if w.taskTimeout == 0 { + if w.TaskTimeout == 0 { return } w.mu.Lock() defer w.mu.Unlock() - if !w.started { + if !w.running { return } log.Infof("Stopping watchdog") w.stop <- struct{}{} <-w.done - w.started = false + w.running = false log.Infof("Watchdog stopped") } +// waitForStart waits for Start to be called and takes action if it does not +// happen within the startup timeout. +func (w *Watchdog) waitForStart() { + <-time.After(w.StartupTimeout) + w.mu.Lock() + defer w.mu.Unlock() + if w.startCalled { + // We are fine. + return + } + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s:\n", w.StartupTimeout)) + w.doAction(w.StartupTimeoutAction, false, &buf) +} + // loop is the main watchdog routine. It only returns when 'Stop()' is called. func (w *Watchdog) loop() { // Loop until someone stops it. @@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() { select { case <-done: - case <-time.After(w.taskTimeout): + case <-time.After(w.TaskTimeout): // Report if the watchdog is not making progress. // No one is watching the watchdog watcher though. w.reportStuckWatchdog() @@ -231,12 +282,14 @@ func (w *Watchdog) runTurn() { if tsched.State == kernel.TaskGoroutineRunningSys { lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick))) elapsed := now.Sub(lastUpdateTime) - discount - if elapsed > w.taskTimeout { + if elapsed > w.TaskTimeout { tc, ok := w.offenders[t] if !ok { // New stuck task detected. // - // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel. + // Note that tasks blocked doing IO may be considered stuck in kernel, + // unless they are surrounded by + // Task.UninterruptibleSleepStart/Finish.
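+ // A sketch of that bracketing at a blocking I/O call site (doHostIO is
+ // illustrative, not a real helper):
+ //
+ //	t.UninterruptibleSleepStart(false)
+ //	err := doHostIO()
+ //	t.UninterruptibleSleepFinish(false)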
tc = &offender{lastUpdateTime: lastUpdateTime} stuckTasks.Increment() newTaskFound = true } tc.lastUpdateTime = lastUpdateTime @@ -261,28 +314,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo tid := w.k.TaskSet().Root.IDOfTask(t) buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) } + buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - w.onStuckTask(newTaskFound, &buf) + + // Skip the stack dump only if no new task was detected and a stack + // dump was generated recently. + skipStack := !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod + w.doAction(w.TaskTimeoutAction, skipStack, &buf) } func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer buf.WriteString("Watchdog goroutine is stuck:\n") - w.onStuckTask(true, &buf) + w.doAction(w.TaskTimeoutAction, false, &buf) } -func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { - switch w.timeoutAction { +// doAction will take the given action. If the action is LogWarning and +// skipStack is true, then the stack printing will be skipped. +func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { + switch action { case LogWarning: - // Dump stack only if a new task is detected or if it sometime has passed since - // the last time a stack dump was generated. - if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod { + if skipStack { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) - } else { - log.TracebackAll(msg.String()) - w.lastStackDump = time.Now() + return + } + log.TracebackAll(msg.String()) + w.lastStackDump = time.Now() case Panic: // Panic will skip over running tasks, which is likely the culprit here. So manually @@ -301,5 +360,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { case <-time.After(1 * time.Second): } panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + default: + panic(fmt.Sprintf("Unknown watchdog action %v", action)) + } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 8f5e60a25..acbf0229b 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 47e6b878a..590c241a3 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState { // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { + // ctx is the decode context. + ctx context.Context + // objectsByID is the set of objects in progress. objectsByID map[uint64]*objectState diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 5d9409a45..c5118d3a9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -16,6 +16,7 @@ package state import ( "container/list" + "context" "encoding/binary" "fmt" "io" @@ -38,6 +39,9 @@ type queuedObject struct { // The encoding process is a breadth-first traversal of the object graph. The // inherent races and dependencies are much simpler than the decode case.
type encodeState struct { + // ctx is the encode context. + ctx context.Context + // lastID is the last object ID. // // See idsByObject for context. Because of the special zero encoding diff --git a/pkg/state/map.go b/pkg/state/map.go index 7e6fefed4..4f3ebb0da 100644 --- a/pkg/state/map.go +++ b/pkg/state/map.go @@ -15,6 +15,7 @@ package state import ( + "context" "fmt" "reflect" "sort" @@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) { // data dependencies have been cleared. m.os.callbacks = append(m.os.callbacks, fn) } + +// Context returns the current context object. +func (m Map) Context() context.Context { + if m.es != nil { + return m.es.ctx + } else if m.ds != nil { + return m.ds.ctx + } + return context.Background() // No context. +} diff --git a/pkg/state/object.proto b/pkg/state/object.proto index 952289069..5ebcfb151 100644 --- a/pkg/state/object.proto +++ b/pkg/state/object.proto @@ -18,8 +18,8 @@ package gvisor.state.statefile; // Slice is a slice value. message Slice { - uint32 length = 1; - uint32 capacity = 2; + uint32 length = 1; + uint32 capacity = 2; uint64 ref_value = 3; } @@ -30,13 +30,13 @@ message Array { // Map is a map value. message Map { - repeated Object keys = 1; + repeated Object keys = 1; repeated Object values = 2; } // Interface is an interface value. message Interface { - string type = 1; + string type = 1; Object value = 2; } @@ -47,7 +47,7 @@ message Struct { // Field encodes a single field. message Field { - string name = 1; + string name = 1; Object value = 2; } @@ -113,28 +113,28 @@ message Float32s { // Note that ref_value references an Object.id, below. message Object { oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; + bool bool_value = 1; + bytes string_value = 2; + int64 int64_value = 3; + uint64 uint64_value = 4; + double double_value = 5; + uint64 ref_value = 6; + Slice slice_value = 7; + Array array_value = 8; + Interface interface_value = 9; + Struct struct_value = 10; + Map map_value = 11; + bytes byte_array_value = 12; + Uint16s uint16_array_value = 13; + Uint32s uint32_array_value = 14; + Uint64s uint64_array_value = 15; + Uintptrs uintptr_array_value = 16; + Int8s int8_array_value = 17; + Int16s int16_array_value = 18; + Int32s int32_array_value = 19; + Int64s int64_array_value = 20; + Bools bool_array_value = 21; + Float64s float64_array_value = 22; + Float32s float32_array_value = 23; } } diff --git a/pkg/state/state.go b/pkg/state/state.go index d408ff84a..dbe507ab4 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -50,6 +50,7 @@ package state import ( + "context" "fmt" "io" "reflect" @@ -86,9 +87,10 @@ func UnwrapErrState(err error) error { } // Save saves the given object state. -func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { // Create the encoding state. 
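+	// ctx is retained here so that Map.Context() (added in map.go above)
+	// can expose it to savable types' save/load hooks.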
es := &encodeState{ + ctx: ctx, idsByObject: make(map[uintptr]uint64), w: w, stats: stats, @@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { } // Load loads a checkpoint. -func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { // Create the decoding state. ds := &decodeState{ + ctx: ctx, objectsByID: make(map[uint64]*objectState), deferred: make(map[uint64]*pb.Object), r: r, diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 7c24bbcda..d7221e9e8 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "io/ioutil" "math" "reflect" @@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) { saveBuffer := &bytes.Buffer{} saveObjectPtr := reflect.New(reflect.TypeOf(root)) saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Save failed unexpectedly: %v", err) continue } else if err != nil { @@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) { // Load a new copy of the object. loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Load failed unexpectedly: %v", err) continue } else if err != nil { @@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) { bs := buildObject(b.N) var stats Stats b.StartTimer() - if err := Save(ioutil.Discard, bs, &stats); err != nil { + if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { b.Errorf("save failed: %v", err) } b.StopTimer() @@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) { bs := buildObject(b.N) var newBS benchStruct buf := &bytes.Buffer{} - if err := Save(buf, bs, nil); err != nil { + if err := Save(context.Background(), buf, bs, nil); err != nil { b.Errorf("save failed: %v", err) } var stats Stats b.StartTimer() - if err := Load(buf, &newBS, &stats); err != nil { + if err := Load(context.Background(), buf, &newBS, &stats); err != nil { b.Errorf("load failed: %v", err) } b.StopTimer() diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD new file mode 100644 index 000000000..cb1f41628 --- /dev/null +++ b/pkg/syncutil/BUILD @@ -0,0 +1,52 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +exports_files(["LICENSE"]) + +go_template( + name = "generic_atomicptr", + srcs = ["atomicptr_unsafe.go"], + types = [ + "Value", + ], +) + +go_template( + name = "generic_seqatomic", + srcs = ["seqatomic_unsafe.go"], + types = [ + "Value", + ], + deps = [ + ":sync", + ], +) + +go_library( + name = "syncutil", + srcs = [ + "downgradable_rwmutex_unsafe.go", + "memmove_unsafe.go", + "norace_unsafe.go", + "race_unsafe.go", + "seqcount.go", + "syncutil.go", + ], + importpath = "gvisor.dev/gvisor/pkg/syncutil", +) + +go_test( + name = "syncutil_test", + size = "small", + srcs = [ + "downgradable_rwmutex_test.go", + "seqcount_test.go", + ], + embed 
= [":syncutil"], +) diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/syncutil/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md new file mode 100644 index 000000000..2183c4e20 --- /dev/null +++ b/pkg/syncutil/README.md @@ -0,0 +1,5 @@ +# Syncutil + +This package provides additional synchronization primitives not provided by the +Go stdlib 'sync' package. It is partially derived from the upstream 'sync' +package from go1.10. diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go new file mode 100644 index 000000000..525c4beed --- /dev/null +++ b/pkg/syncutil/atomicptr_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "sync/atomic" + "unsafe" +) + +// Value is a required type parameter. +type Value struct{} + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtr struct { + ptr unsafe.Pointer `state:".(*Value)"` +} + +func (p *AtomicPtr) savePtr() *Value { + return p.Load() +} + +func (p *AtomicPtr) loadPtr(v *Value) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtr) Load() *Value { + return (*Value)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. 
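+// Store may be called concurrently with Load and with other Stores; to
+// publish safely, fully initialize *x before the Store.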
+func (p *AtomicPtr) Store(x *Value) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD new file mode 100644 index 000000000..63f411a90 --- /dev/null +++ b/pkg/syncutil/atomicptrtest/BUILD @@ -0,0 +1,29 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "atomicptr_int", + out = "atomicptr_int_unsafe.go", + package = "atomicptr", + suffix = "Int", + template = "//pkg/syncutil:generic_atomicptr", + types = { + "Value": "int", + }, +) + +go_library( + name = "atomicptr", + srcs = ["atomicptr_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", +) + +go_test( + name = "atomicptr_test", + size = "small", + srcs = ["atomicptr_test.go"], + embed = [":atomicptr"], +) diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go new file mode 100644 index 000000000..8fdc5112e --- /dev/null +++ b/pkg/syncutil/atomicptrtest/atomicptr_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package atomicptr + +import ( + "testing" +) + +func newInt(val int) *int { + return &val +} + +func TestAtomicPtr(t *testing.T) { + var p AtomicPtrInt + if got := p.Load(); got != nil { + t.Errorf("initial value is %p (%v), wanted nil", got, got) + } + want := newInt(42) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } + want = newInt(100) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } +} diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go new file mode 100644 index 000000000..ffaf7ecc7 --- /dev/null +++ b/pkg/syncutil/downgradable_rwmutex_test.go @@ -0,0 +1,150 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// GOMAXPROCS=10 go test + +// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the +// addition of downgradingWriter and the renaming of num_iterations to +// numIterations to shut up Golint. + +package syncutil + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" +) + +func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { + m.RLock() + clocked <- true + <-cunlock + m.RUnlock() + cdone <- true +} + +func doTestParallelReaders(numReaders, gomaxprocs int) { + runtime.GOMAXPROCS(gomaxprocs) + var m DowngradableRWMutex + clocked := make(chan bool) + cunlock := make(chan bool) + cdone := make(chan bool) + for i := 0; i < numReaders; i++ { + go parallelReader(&m, clocked, cunlock, cdone) + } + // Wait for all parallel RLock()s to succeed. + for i := 0; i < numReaders; i++ { + <-clocked + } + for i := 0; i < numReaders; i++ { + cunlock <- true + } + // Wait for the goroutines to finish. 
+ for i := 0; i < numReaders; i++ { + <-cdone + } +} + +func TestParallelReaders(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + doTestParallelReaders(1, 4) + doTestParallelReaders(3, 4) + doTestParallelReaders(4, 2) +} + +func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.RLock() + n := atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.Unlock() + } + cdone <- true +} + +func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.DowngradeLock() + n = atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + n = atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { + runtime.GOMAXPROCS(gomaxprocs) + // Number of active readers + 10000 * number of active writers. + var activity int32 + var rwm DowngradableRWMutex + cdone := make(chan bool) + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + var i int + for i = 0; i < numReaders/2; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + for ; i < numReaders; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + // Wait for the 4 writers and all readers to finish. + for i := 0; i < 4+numReaders; i++ { + <-cdone + } +} + +func TestDowngradableRWMutex(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + n := 1000 + if testing.Short() { + n = 5 + } + HammerDowngradableRWMutex(1, 1, n) + HammerDowngradableRWMutex(1, 3, n) + HammerDowngradableRWMutex(1, 10, n) + HammerDowngradableRWMutex(4, 1, n) + HammerDowngradableRWMutex(4, 3, n) + HammerDowngradableRWMutex(4, 10, n) + HammerDowngradableRWMutex(10, 1, n) + HammerDowngradableRWMutex(10, 3, n) + HammerDowngradableRWMutex(10, 10, n) + HammerDowngradableRWMutex(10, 5, n) +} diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go new file mode 100644 index 000000000..51e11555d --- /dev/null +++ b/pkg/syncutil/downgradable_rwmutex_unsafe.go @@ -0,0 +1,146 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +// This is mostly copied from the standard library's sync/rwmutex.go. 
+// +// Happens-before relationships indicated to the race detector: +// - Unlock -> Lock (via writerSem) +// - Unlock -> RLock (via readerSem) +// - RUnlock -> Lock (via writerSem) +// - DowngradeLock -> RLock (via readerSem) + +package syncutil + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +//go:linkname runtimeSemacquire sync.runtime_Semacquire +func runtimeSemacquire(s *uint32) + +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + +// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock +// method. +type DowngradableRWMutex struct { + w sync.Mutex // held if there are pending writers + writerSem uint32 // semaphore for writers to wait for completing readers + readerSem uint32 // semaphore for readers to wait for completing writers + readerCount int32 // number of pending readers + readerWait int32 // number of departing readers +} + +const rwmutexMaxReaders = 1 << 30 + +// RLock locks rw for reading. +func (rw *DowngradableRWMutex) RLock() { + if RaceEnabled { + RaceDisable() + } + if atomic.AddInt32(&rw.readerCount, 1) < 0 { + // A writer is pending, wait for it. + runtimeSemacquire(&rw.readerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.readerSem)) + } +} + +// RUnlock undoes a single RLock call. +func (rw *DowngradableRWMutex) RUnlock() { + if RaceEnabled { + RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) + RaceDisable() + } + if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { + if r+1 == 0 || r+1 == -rwmutexMaxReaders { + panic("RUnlock of unlocked DowngradableRWMutex") + } + // A writer is pending. + if atomic.AddInt32(&rw.readerWait, -1) == 0 { + // The last reader unblocks the writer. + runtimeSemrelease(&rw.writerSem, false, 0) + } + } + if RaceEnabled { + RaceEnable() + } +} + +// Lock locks rw for writing. +func (rw *DowngradableRWMutex) Lock() { + if RaceEnabled { + RaceDisable() + } + // First, resolve competition with other writers. + rw.w.Lock() + // Announce to readers there is a pending writer. + r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders + // Wait for active readers. + if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { + runtimeSemacquire(&rw.writerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.writerSem)) + } +} + +// Unlock unlocks rw for writing. +func (rw *DowngradableRWMutex) Unlock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.writerSem)) + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) + if r >= rwmutexMaxReaders { + panic("Unlock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. + for i := 0; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +func (rw *DowngradableRWMutex) DowngradeLock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer and one additional reader. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) + if r >= rwmutexMaxReaders+1 { + panic("DowngradeLock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. Note that this loop starts at 1 since r + // includes this goroutine. 
+ for i := 1; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed to rw.w.Lock(). Note that they will still + // block on rw.writerSem since at least this reader exists, such that + // DowngradeLock() is atomic with the previous write lock. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go new file mode 100644 index 000000000..348675baa --- /dev/null +++ b/pkg/syncutil/memmove_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package syncutil + +import ( + "unsafe" +) + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't +// define it because go_generics can't update the go:linkname annotation. +// Furthermore, go:linkname silently doesn't work if the local name is exported +// (this is of course undocumented), which is why this indirection is +// necessary. +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go new file mode 100644 index 000000000..0a0a9deda --- /dev/null +++ b/pkg/syncutil/norace_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race + +package syncutil + +import ( + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = false + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { +} diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go new file mode 100644 index 000000000..206067ec1 --- /dev/null +++ b/pkg/syncutil/race_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package syncutil + +import ( + "runtime" + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = true + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { + runtime.RaceDisable() +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { + runtime.RaceEnable() +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { + runtime.RaceRelease(addr) +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. 
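For orientation, here is a minimal usage sketch (not part of this change) of the DowngradableRWMutex added above: fill state under the write lock, then downgrade so the fill and the subsequent read are atomic with respect to other writers, without ever fully releasing the lock. The cache type and its fields are invented for illustration.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/syncutil"
)

// cache is a hypothetical structure guarded by a DowngradableRWMutex.
type cache struct {
	mu   syncutil.DowngradableRWMutex
	data map[string]string
}

// getOrFill fills the entry for key under the write lock if needed, then
// downgrades to a read lock so the value can be read (and concurrent readers
// can proceed) with no window in which another writer could run.
func (c *cache) getOrFill(key string) string {
	c.mu.Lock()
	if c.data == nil {
		c.data = make(map[string]string)
	}
	if _, ok := c.data[key]; !ok {
		c.data[key] = "value-for-" + key
	}
	// Atomically convert the write lock into a read lock.
	c.mu.DowngradeLock()
	v := c.data[key]
	c.mu.RUnlock() // A downgraded lock is released with RUnlock.
	return v
}

func main() {
	var c cache
	fmt.Println(c.getOrFill("k"))
}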
+func RaceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go new file mode 100644 index 000000000..cb6d2eb22 --- /dev/null +++ b/pkg/syncutil/seqatomic_unsafe.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/syncutil" +) + +// Value is a required type parameter. +// +// Value must not contain any pointers, including interface objects, function +// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs +// containing any of the above. An init() function will panic if this property +// does not hold. +type Value struct{} + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Value + for { + epoch := sc.BeginRead() + if syncutil.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, + // so it can't help us here. Instead, call runtime.memmove + // directly, which is not instrumented by the race detector. + syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + // This is ~40% faster for short reads than going through memmove. + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). 
+func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { + var val Value + if syncutil.RaceEnabled { + syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func init() { + var val Value + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD new file mode 100644 index 000000000..ba18f3238 --- /dev/null +++ b/pkg/syncutil/seqatomictest/BUILD @@ -0,0 +1,35 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "seqatomic_int", + out = "seqatomic_int_unsafe.go", + package = "seqatomic", + suffix = "Int", + template = "//pkg/syncutil:generic_seqatomic", + types = { + "Value": "int", + }, +) + +go_library( + name = "seqatomic", + srcs = ["seqatomic_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", + deps = [ + "//pkg/syncutil", + ], +) + +go_test( + name = "seqatomic_test", + size = "small", + srcs = ["seqatomic_test.go"], + embed = [":seqatomic"], + deps = [ + "//pkg/syncutil", + ], +) diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go new file mode 100644 index 000000000..b0db44999 --- /dev/null +++ b/pkg/syncutil/seqatomictest/seqatomic_test.go @@ -0,0 +1,132 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
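To make the template above concrete, the following is a hand-expanded sketch of roughly what go_template_instance generates for a pointer-free Value type, together with the writer side using SeqCount. The Point and Store names are hypothetical; the real instantiations in this change are the int versions produced by the BUILD rules above and exercised by the test below.

package pointstore

import (
	"unsafe"

	"gvisor.dev/gvisor/pkg/syncutil"
)

// Point is a hypothetical pointer-free value that is too wide to load
// atomically as a single machine word.
type Point struct {
	X, Y int64
}

// SeqAtomicLoadPoint is what the generic_seqatomic template expands to for
// Value = Point (hand-expanded here for illustration).
func SeqAtomicLoadPoint(sc *syncutil.SeqCount, ptr *Point) Point {
	var val Point
	for {
		epoch := sc.BeginRead()
		if syncutil.RaceEnabled {
			// Bypass race instrumentation; this read may race by design.
			syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
		} else {
			val = *ptr
		}
		if sc.ReadOk(epoch) {
			break
		}
	}
	return val
}

// Store pairs the protected value with its SeqCount. Writers must be
// serialized externally (e.g. with a sync.Mutex) per the SeqCount contract.
type Store struct {
	seq syncutil.SeqCount
	pt  Point
}

// Set publishes a new value inside a writer critical section.
func (s *Store) Set(p Point) {
	s.seq.BeginWrite()
	s.pt = p
	s.seq.EndWrite()
}

// Get returns a consistent snapshot without blocking writers.
func (s *Store) Get() Point {
	return SeqAtomicLoadPoint(&s.seq, &s.pt)
}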
+ +package seqatomic + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/syncutil" +) + +func TestSeqAtomicLoadUncontended(t *testing.T) { + var seq syncutil.SeqCount + const want = 1 + data := want + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadAfterWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadDuringWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicTryLoadUncontended(t *testing.T) { + var seq syncutil.SeqCount + const want = 1 + data := want + epoch := seq.BeginRead() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } +} + +func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } + seq.EndWrite() +} + +func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { + var seq syncutil.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } +} + +func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { + var seq syncutil.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := SeqAtomicLoadInt(&seq, &data); got != want { + b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } + } + }) +} + +func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { + var seq syncutil.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + epoch := seq.BeginRead() + for pb.Next() { + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } + } + }) +} + +// For comparison: +func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { + var a atomic.Value + const want = 42 + a.Store(int(want)) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := a.Load().(int); got != want { + b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) + } + } + }) +} diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go new file mode 100644 index 000000000..11d8dbfaa --- /dev/null +++ b/pkg/syncutil/seqcount.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncutil + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" +) + +// SeqCount is a synchronization primitive for optimistic reader/writer +// synchronization in cases where readers can work with stale data and +// therefore do not need to block writers. 
+// +// Compared to sync/atomic.Value: +// +// - Mutation of SeqCount-protected data does not require memory allocation, +// whereas atomic.Value generally does. This is a significant advantage when +// writes are common. +// +// - Atomic reads of SeqCount-protected data require copying. This is a +// disadvantage when atomic reads are common. +// +// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other +// operations to be made atomic with reads of SeqCount-protected data. +// +// - SeqCount may be less flexible: as of this writing, SeqCount-protected data +// cannot include pointers. +// +// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected +// data require instantiating function templates using go_generics (see +// seqatomic.go). +type SeqCount struct { + // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd + // if a writer critical section is active, and a read from data protected + // by this SeqCount is atomic iff epoch is the same even value before and + // after the read. + epoch uint32 +} + +// SeqCountEpoch tracks writer critical sections in a SeqCount. +type SeqCountEpoch struct { + val uint32 +} + +// We assume that: +// +// - All functions in sync/atomic that perform a memory read are at least a +// read fence: memory reads before calls to such functions cannot be reordered +// after the call, and memory reads after calls to such functions cannot be +// reordered before the call, even if those reads do not use sync/atomic. +// +// - All functions in sync/atomic that perform a memory write are at least a +// write fence: memory writes before calls to such functions cannot be +// reordered after the call, and memory writes after calls to such functions +// cannot be reordered before the call, even if those writes do not use +// sync/atomic. +// +// As of this writing, the Go memory model completely fails to describe +// sync/atomic, but these properties are implied by +// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. + +// BeginRead indicates the beginning of a reader critical section. Reader +// critical sections DO NOT BLOCK writer critical sections, so operations in a +// reader critical section MAY RACE with writer critical sections. Races are +// detected by ReadOk at the end of the reader critical section. Thus, the +// low-level structure of readers is generally: +// +// for { +// epoch := seq.BeginRead() +// // do something idempotent with seq-protected data +// if seq.ReadOk(epoch) { +// break +// } +// } +// +// However, since reader critical sections may race with writer critical +// sections, the Go race detector will (accurately) flag data races in readers +// using this pattern. Most users of SeqCount will need to use the +// SeqAtomicLoad function template in seqatomic.go. +func (s *SeqCount) BeginRead() SeqCountEpoch { + epoch := atomic.LoadUint32(&s.epoch) + for epoch&1 != 0 { + runtime.Gosched() + epoch = atomic.LoadUint32(&s.epoch) + } + return SeqCountEpoch{epoch} +} + +// ReadOk returns true if the reader critical section initiated by a previous +// call to BeginRead() that returned epoch did not race with any writer critical +// sections. +// +// ReadOk may be called any number of times during a reader critical section. +// Reader critical sections do not need to be explicitly terminated; the last +// call to ReadOk is implicitly the end of the reader critical section. 
+func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { + return atomic.LoadUint32(&s.epoch) == epoch.val +} + +// BeginWrite indicates the beginning of a writer critical section. +// +// SeqCount does not support concurrent writer critical sections; clients with +// concurrent writers must synchronize them using e.g. sync.Mutex. +func (s *SeqCount) BeginWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { + panic("SeqCount.BeginWrite during writer critical section") + } +} + +// EndWrite ends the effect of a preceding BeginWrite. +func (s *SeqCount) EndWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { + panic("SeqCount.EndWrite outside writer critical section") + } +} + +// PointersInType returns a list of pointers reachable from values named +// valName of the given type. +// +// PointersInType is not exhaustive, but it is guaranteed that if typ contains +// at least one pointer, then PointersInType returns a non-empty list. +func PointersInType(typ reflect.Type, valName string) []string { + switch kind := typ.Kind(); kind { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return nil + + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: + return []string{valName} + + case reflect.Array: + return PointersInType(typ.Elem(), valName+"[]") + + case reflect.Struct: + var ptrs []string + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) + } + return ptrs + + default: + return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} + } +} diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go new file mode 100644 index 000000000..14d6aedea --- /dev/null +++ b/pkg/syncutil/seqcount_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package syncutil + +import ( + "reflect" + "testing" + "time" +) + +func TestSeqCountWriteUncontended(t *testing.T) { + var seq SeqCount + seq.BeginWrite() + seq.EndWrite() +} + +func TestSeqCountReadUncontended(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadAfterWrite(t *testing.T) { + var seq SeqCount + var data int32 + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadDuringWrite(t *testing.T) { + var seq SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountReadOkAfterWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } +} + +func TestSeqCountReadOkDuringWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } + seq.EndWrite() +} + +func BenchmarkSeqCountWriteUncontended(b *testing.B) { + var seq SeqCount + for i := 0; i < b.N; i++ { + seq.BeginWrite() + seq.EndWrite() + } +} + +func BenchmarkSeqCountReadUncontended(b *testing.B) { + var seq SeqCount + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + b.Fatalf("ReadOk: got false, wanted true") + } + } + }) +} + +func TestPointersInType(t *testing.T) { + for _, test := range []struct { + name string // used for both test and value name + val interface{} + ptrs []string + }{ + { + name: "EmptyStruct", + val: struct{}{}, + }, + { + name: "Int", + val: int(0), + }, + { + name: "MixedStruct", + val: struct { + b bool + I int + ExportedPtr *struct{} + unexportedPtr *struct{} + arr [2]int + ptrArr [2]*int + nestedStruct struct { + nestedNonptr int + nestedPtr *int + } + structArr [1]struct { + nonptr int + ptr *int + } + }{}, + ptrs: []string{ + "MixedStruct.ExportedPtr", + "MixedStruct.unexportedPtr", + "MixedStruct.ptrArr[]", + "MixedStruct.nestedStruct.nestedPtr", + "MixedStruct.structArr[].ptr", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + typ := reflect.TypeOf(test.val) + ptrs := PointersInType(typ, test.name) + t.Logf("Found pointers: %v", ptrs) + if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { + t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) + } + }) + } +} diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go new file mode 100644 index 000000000..66e750d06 --- /dev/null +++ b/pkg/syncutil/syncutil.go @@ -0,0 +1,7 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package syncutil provides synchronization primitives. 
+package syncutil diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 3c2b2b5ea..65d4d0cd8 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "packet_buffer.go", + "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", ], diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 4287464f3..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable { return Prependable{buf: v, usedIdx: 0} } +// NewEmptyPrependableFromView creates a new prependable buffer from a View. +func NewEmptyPrependableFromView(v View) Prependable { + return Prependable{buf: v, usedIdx: len(v)} +} + // View returns a View of the backing buffer that contains all prepended // data so far. func (p Prependable) View() View { @@ -72,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index 4cecfb989..b6fa6fc37 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 02137e1c9..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -639,6 +640,8 @@ func ICMPv4Code(want byte) TransportChecker { // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and // potentially additional ICMPv6 header fields. +// +// ICMPv6 will validate the checksum field before calling checkers. 
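One note on the buffer.Prependable additions above: NewEmptyPrependableFromView repurposes an existing View's bytes entirely as headroom (View() starts out empty, and Prepend grows backward into the existing storage without allocating), while DeepCopy detaches a Prependable from its backing bytes. A small, hypothetical sketch:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/buffer"
)

func main() {
	// All 8 bytes of v become headroom; nothing is "used" yet.
	v := buffer.NewView(8)
	p := buffer.NewEmptyPrependableFromView(v)
	copy(p.Prepend(4), []byte{1, 2, 3, 4})
	fmt.Println(p.View()) // [1 2 3 4]

	// DeepCopy detaches q from p's backing bytes.
	q := p.DeepCopy()
	p.Prepend(2)[0] = 9       // mutate p's backing storage
	fmt.Println(q.View())     // still [1 2 3 4]
}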
func ICMPv6(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -650,6 +653,10 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) + if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) + } + for _, f := range checkers { f(t, icmp) } diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 0c5c20cea..e648efa71 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -7,9 +7,7 @@ go_library( name = "jenkins", srcs = ["jenkins.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 07d09abed..f1d837196 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -19,6 +19,7 @@ go_library( "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", "ndp_options.go", + "ndp_router_advert.go", "tcp.go", "udp.go", ], @@ -36,16 +37,29 @@ go_test( name = "header_x_test", size = "small", srcs = [ + "checksum_test.go", + "ipv6_test.go", "ipversion_test.go", "tcp_test.go", ], - deps = [":header"], + deps = [ + ":header", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) go_test( name = "header_test", size = "small", - srcs = ["ndp_test.go"], + srcs = [ + "eth_test.go", + "ndp_test.go", + ], embed = [":header"], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 39a4d69be..9749c7f4d 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -23,11 +23,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -func calculateChecksum(buf []byte, initial uint32) uint16 { +func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + l := len(buf) - if l&1 != 0 { + odd = l&1 != 0 + if odd { l-- v += uint32(buf[l]) << 8 } @@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } - return ChecksumCombine(uint16(v), uint16(v>>16)) + return ChecksumCombine(uint16(v), uint16(v>>16)), odd } // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - return calculateChecksum(buf, uint32(initial)) + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s } // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - var odd bool + return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) +} + +// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the +// bytes in the given VectorisedView. +// +// The initial checksum must have been computed on an even number of bytes. 
+func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { + odd := false sum := initial for _, v := range vv.Views() { if len(v) == 0 { continue } - s := uint32(sum) - if odd { - s += uint32(v[0]) - v = v[1:] + + if off >= len(v) { + off -= len(v) + continue + } + v = v[off:] + + l := len(v) + if l > size { + l = size + } + v = v[:l] + + sum, odd = calculateChecksum(v, odd, uint32(sum)) + + size -= len(v) + if size == 0 { + break } - odd = len(v)&1 != 0 - sum = calculateChecksum(v, s) + off = 0 } return sum } diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go new file mode 100644 index 000000000..86b466c1c --- /dev/null +++ b/pkg/tcpip/header/checksum_test.go @@ -0,0 +1,109 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package header provides the implementation of the encoding and decoding of +// network protocol headers. +package header_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestChecksumVVWithOffset(t *testing.T) { + testCases := []struct { + name string + vv buffer.VectorisedView + off, size int + initial uint16 + want uint16 + }{ + { + name: "empty", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 0, + want: 0, + }, + { + name: "OneView", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 5, + want: 1294, + }, + { + name: "TwoViews", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 0, + size: 11, + want: 33819, + }, + { + name: "TwoViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 1, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 7, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithInitial", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), + }), + initial: 77, + off: 7, + size: 11, + want: 33896, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { + t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + } + v := tc.vv.ToView() + v.TrimFront(tc.off) + v.CapLength(tc.size) + if got, want := header.Checksum(v, 
tc.initial), tc.want; got != want { + t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + } + }) + } +} diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 4c3d3311f..f5d2c127f 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -48,8 +48,48 @@ const ( // EthernetAddressSize is the size, in bytes, of an ethernet address. EthernetAddressSize = 6 + + // unspecifiedEthernetAddress is the unspecified ethernet address + // (all bits set to 0). + unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + + // unicastMulticastFlagMask is the mask of the least significant bit in + // the first octet (in network byte order) of an ethernet address that + // determines whether the ethernet address is a unicast or multicast + // address. If the masked bit is a 1, then the address is multicast; + // otherwise it is unicast. + // + // See the IEEE Std 802-2001 document for more details. Specifically, + // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf: + // "A 48-bit universal address consists of two parts. The first 24 bits + // correspond to the OUI as assigned by the IEEE, except that the + // assignee may set the LSB of the first octet to 1 for group addresses + // or set it to 0 for individual addresses." + unicastMulticastFlagMask = 1 + + // unicastMulticastFlagByteIdx is the byte that holds the + // unicast/multicast flag. See unicastMulticastFlagMask. + unicastMulticastFlagByteIdx = 0 +) + +const ( + // EthernetProtocolAll is a catch-all for all protocols carried inside + // an ethernet frame. It is mainly used to create packet sockets that + // capture all traffic. + EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003 + + // EthernetProtocolPUP is the PARC Universal Packet protocol ethertype. + EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200 ) +// Ethertypes holds the protocol numbers describing the payload of an ethernet +// frame. These types aren't necessarily supported by netstack, but can be used +// to catch all traffic of a type via packet endpoints. +var Ethertypes = []tcpip.NetworkProtocolNumber{ + EthernetProtocolAll, + EthernetProtocolPUP, +} + // SourceAddress returns the "MAC source" field of the ethernet frame header. func (b Ethernet) SourceAddress() tcpip.LinkAddress { return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize]) @@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr) copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } + +// IsValidUnicastEthernetAddress returns true if addr is a valid unicast +// ethernet address. +func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { + // Must be of the right length. + if len(addr) != EthernetAddressSize { + return false + } + + // Must not be unspecified. + if addr == unspecifiedEthernetAddress { + return false + } + + // Must not be a multicast. + if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { + return false + } + + // addr is a valid unicast ethernet address. + return true +} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go new file mode 100644 index 000000000..6634c90f5 --- /dev/null +++ b/pkg/tcpip/header/eth_test.go @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
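Circling back to ChecksumVVWithOffset above: a minimal sketch of checksumming a byte range that spans views, mirroring the TwoViewsWithOffset case in the new test (the expected value 33819 is taken from that test table; odd/even alignment across the view boundary is handled internally).

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	vv := buffer.NewVectorisedView(0, []buffer.View{
		buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
		buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
	})
	// Skip the leading 98 byte and checksum the next 11 bytes.
	sum := header.ChecksumVVWithOffset(vv, 0 /* initial */, 1 /* off */, 11 /* size */)
	fmt.Println(sum) // 33819, per the TwoViewsWithOffset test case above.
}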
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestIsValidUnicastEthernetAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.LinkAddress + expected bool + }{ + { + "Nil", + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty", + tcpip.LinkAddress(""), + false, + }, + { + "InvalidLength", + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Unspecified", + unspecifiedEthernetAddress, + false, + }, + { + "Multicast", + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Valid", + tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), + true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { + t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index e51c5098c..b4037b6c8 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -80,6 +80,13 @@ const ( // icmpv6SequenceOffset is the offset of the sequence field // in a ICMPv6 Echo Request/Reply message. icmpv6SequenceOffset = 6 + + // NDPHopLimit is the expected IP hop limit value of 255 for received + // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, + // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently + // drop the NDP packet. All outgoing NDP packets must use this value for + // its IP hop limit field. + NDPHopLimit = 255 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -125,7 +132,7 @@ func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } -// SetChecksum calculates and sets the ICMP checksum field. +// SetChecksum sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } @@ -190,7 +197,7 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } -// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { // Calculate the IPv6 pseudo-header upper-layer checksum. diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index b125bbea5..fc671e439 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -90,9 +90,23 @@ const ( // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. 
That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 ) -// IPv6EmptySubnet is the empty IPv6 subnet. +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. var IPv6EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) if err != nil { @@ -101,6 +115,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined +// by RFC 4291 section 2.5.6. +// +// The prefix is fe80::/64 +var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{ + Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 64, +} + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { @@ -255,27 +278,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFF + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { - // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local - // header, FE80::. + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. // // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go new file mode 100644 index 000000000..42c5c6fc1 --- /dev/null +++ b/pkg/tcpip/header/ipv6_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
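As a worked example of the modified EUI-64 conversion above (a sketch using the helpers just added): for aa:bb:cc:dd:ee:ff, the universal/local bit of the first octet is flipped (0xaa ^ 0x02 = 0xa8) and 0xff, 0xfe is spliced into the middle, and LinkLocalAddr prepends the fe80::/64 prefix to that IID.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	mac := tcpip.LinkAddress("\xaa\xbb\xcc\xdd\xee\xff")

	// IID: first octet bit-flipped, 0xff 0xfe spliced into the middle.
	iid := header.EthernetAddressToModifiedEUI64(mac)
	fmt.Printf("% x\n", iid) // a8 bb cc ff fe dd ee ff

	// Link-local address: the fe80::/64 prefix followed by the IID.
	fmt.Printf("% x\n", []byte(header.LinkLocalAddr(mac)))
	// fe 80 00 00 00 00 00 00 a8 bb cc ff fe dd ee ff
}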
+// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +func TestEthernetAdddressToModifiedEUI64(t *testing.T) { + expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} + + if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) + } + + var buf [header.IIDSize]byte + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + if diff := cmp.Diff(expectedIID, buf); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) + } +} + +func TestLinkLocalAddr(t *testing.T) { + if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) + } +} diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go index 5c2b472c8..505c92668 100644 --- a/pkg/tcpip/header/ndp_neighbor_advert.go +++ b/pkg/tcpip/header/ndp_neighbor_advert.go @@ -18,6 +18,8 @@ import "gvisor.dev/gvisor/pkg/tcpip" // NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will // only contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.4 for more details. type NDPNeighborAdvert []byte const ( diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go index 1dcb0fbc6..3a1b8e139 100644 --- a/pkg/tcpip/header/ndp_neighbor_solicit.go +++ b/pkg/tcpip/header/ndp_neighbor_solicit.go @@ -18,6 +18,8 @@ import "gvisor.dev/gvisor/pkg/tcpip" // NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only // contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.3 for more details. type NDPNeighborSolicit []byte const ( diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index b28bde15b..06e0bace2 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -15,6 +15,11 @@ package header import ( + "encoding/binary" + "errors" + "math" + "time" + "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,15 +32,226 @@ const ( // Link Layer Option for an Ethernet address. ndpTargetEthernetLinkLayerAddressSize = 8 + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType = 3 + + // ndpPrefixInformationLength is the expected length, in bytes, of the + // body of an NDP Prefix Information option, as per RFC 4861 section + // 4.6.2 which specifies that the Length field is 4. Given this, the + // expected length, in bytes, is 30 because 4 * lengthByteUnits (8) - 2 + // (Type & Length) = 30. + ndpPrefixInformationLength = 30 + + // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix + // Length field within an NDPPrefixInformation. + ndpPrefixInformationPrefixLengthOffset = 0 + + // ndpPrefixInformationFlagsOffset is the offset of the flags byte + // within an NDPPrefixInformation. + ndpPrefixInformationFlagsOffset = 1 + + // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag + // field in the flags byte within an NDPPrefixInformation. 
+ ndpPrefixInformationOnLinkFlagMask = (1 << 7) + + // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the + // Autonomous Address-Configuration flag field in the flags byte within + // an NDPPrefixInformation. + ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + + // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationReserved1FlagsMask = 63 + + // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte + // Valid Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationValidLifetimeOffset = 2 + + // ndpPrefixInformationPreferredLifetimeOffset is the start of the + // 4-byte Preferred Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationPreferredLifetimeOffset = 6 + + // ndpPrefixInformationReserved2Offset is the start of the 4-byte + // Reserved2 field within an NDPPrefixInformation. + ndpPrefixInformationReserved2Offset = 10 + + // ndpPrefixInformationReserved2Length is the length of the Reserved2 + // field. + // + // It is 4 bytes. + ndpPrefixInformationReserved2Length = 4 + + // ndpPrefixInformationPrefixOffset is the start of the Prefix field + // within an NDPPrefixInformation. + ndpPrefixInformationPrefixOffset = 14 + + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType = 25 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS + // Server option's length field value when it contains at least one + // IPv6 address. + minNDPRecursiveDNSServerLength = 3 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 ) +var ( + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. + // + // This is a variable instead of a constant so that tests can change + // this value to a smaller value. It should only be modified by tests. + NDPInfiniteLifetime = time.Second * math.MaxUint32 +) + +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + // The NDPOptions this NDPOptionIterator is iterating over. + opts NDPOptions +} + +// Potential errors when iterating over an NDPOptions. 
+var (
+	ErrNDPOptBufExhausted  = errors.New("Buffer unexpectedly exhausted")
+	ErrNDPOptZeroLength    = errors.New("NDP option has zero-valued Length field")
+	ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+	ErrNDPInvalidLength    = errors.New("NDP option's Length value is invalid as per relevant RFC")
+)
+
+// Next returns the next element in the backing NDPOptions, whether iteration
+// is done, and an error if one occurred.
+//
+// The return can be read as option, done, error. Note, option should only be
+// used if done is false and error is nil.
+func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
+	for {
+		// Do we still have elements to look at?
+		if len(i.opts) == 0 {
+			return nil, true, nil
+		}
+
+		// Do we have enough bytes for an NDP option that has a Length
+		// field of at least 1? Note, 0 in the Length field is invalid.
+		if len(i.opts) < lengthByteUnits {
+			return nil, true, ErrNDPOptBufExhausted
+		}
+
+		// Get the Type field.
+		t := i.opts[0]
+
+		// Get the Length field.
+		l := i.opts[1]
+
+		// This would indicate an erroneous NDP option as the Length
+		// field should never be 0.
+		if l == 0 {
+			return nil, true, ErrNDPOptZeroLength
+		}
+
+		// How many bytes are in the option body?
+		numBytes := int(l) * lengthByteUnits
+		numBodyBytes := numBytes - 2
+
+		potentialBody := i.opts[2:]
+
+		// This would indicate an erroneous NDPOptions buffer as we ran
+		// out of the buffer in the middle of an NDP option.
+		if left := len(potentialBody); left < numBodyBytes {
+			return nil, true, ErrNDPOptBufExhausted
+		}
+
+		// Get only the options body, leaving the rest of the options
+		// buffer alone.
+		body := potentialBody[:numBodyBytes]
+
+		// Update opts with the remaining options.
+		i.opts = i.opts[numBytes:]
+
+		switch t {
+		case NDPTargetLinkLayerAddressOptionType:
+			return NDPTargetLinkLayerAddressOption(body), false, nil
+
+		case NDPPrefixInformationType:
+			// Make sure the length of a Prefix Information option
+			// body is ndpPrefixInformationLength, as per RFC 4861
+			// section 4.6.2.
+			if numBodyBytes != ndpPrefixInformationLength {
+				return nil, true, ErrNDPOptMalformedBody
+			}
+
+			return NDPPrefixInformation(body), false, nil
+
+		case NDPRecursiveDNSServerOptionType:
+			// RFC 8106 section 5.3.1 outlines that the RDNSS option
+			// must have a minimum length of 3 so it contains at
+			// least one IPv6 address.
+			if l < minNDPRecursiveDNSServerLength {
+				return nil, true, ErrNDPInvalidLength
+			}
+
+			opt := NDPRecursiveDNSServer(body)
+			if len(opt.Addresses()) == 0 {
+				return nil, true, ErrNDPOptMalformedBody
+			}
+
+			return opt, false, nil
+
+		default:
+			// We do not yet recognize the option, just skip for
+			// now. This is okay because RFC 4861 allows us to
+			// skip/ignore any unrecognized options. However,
+			// we MUST recognize all the options in RFC 4861.
+			//
+			// TODO(b/141487990): Handle all NDP options as defined
+			// by RFC 4861.
+		}
+	}
+}
+
 // NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6.
 type NDPOptions []byte
 
+// Iter returns an iterator of NDPOption.
+//
+// If check is true, Iter will do an integrity check on the options by
+// iterating over them and returning an error if one is detected.
+//
+// See NDPOptionIterator for more information.
+func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
+	it := NDPOptionIterator{opts: b}
+
+	if check {
+		for it2 := it; true; {
+			if _, done, err := it2.Next(); err != nil || done {
+				return it, err
+			}
+		}
+	}
+
+	return it, nil
+}
+
 // Serialize serializes the provided list of NDP options into o.
 //
 // Note, b must be of sufficient size to hold all the options in s. See
@@ -75,15 +291,15 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
 	return done
 }
 
-// ndpOption is the set of functions to be implemented by all NDP option types.
-type ndpOption interface {
-	// Type returns the type of this ndpOption.
+// NDPOption is the set of functions to be implemented by all NDP option types.
+type NDPOption interface {
+	// Type returns the type of the receiver.
 	Type() uint8
 
-	// Length returns the length of the body of this ndpOption, in bytes.
+	// Length returns the length of the body of the receiver, in bytes.
 	Length() int
 
-	// serializeInto serializes this ndpOption into the provided byte
+	// serializeInto serializes the receiver into the provided byte
 	// buffer.
 	//
 	// Note, the caller MUST provide a byte buffer with size of at least
@@ -92,15 +308,15 @@ type ndpOption interface {
 	// buffer is not of sufficient size.
 	//
 	// serializeInto will return the number of bytes that was used to
-	// serialize this ndpOption. Implementers must only use the number of
-	// bytes required to serialize this ndpOption. Callers MAY provide a
+	// serialize the receiver. Implementers must only use the number of
+	// bytes required to serialize the receiver. Callers MAY provide a
 	// larger buffer than required to serialize into.
 	serializeInto([]byte) int
 }
 
 // paddedLength returns the length of o, in bytes, with any padding bytes, if
 // required.
-func paddedLength(o ndpOption) int {
+func paddedLength(o NDPOption) int {
 	l := o.Length()
 
 	if l == 0 {
@@ -139,7 +355,7 @@ func paddedLength(o ndpOption) int {
 }
 
 // NDPOptionsSerializer is a serializer for NDP options.
-type NDPOptionsSerializer []ndpOption
+type NDPOptionsSerializer []NDPOption
 
 // Length returns the total number of bytes required to serialize.
 func (b NDPOptionsSerializer) Length() int {
@@ -154,19 +370,220 @@ func (b NDPOptionsSerializer) Length() int {
 
 // NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
 // as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length fields,
+// where X is (the value of the Length field * lengthByteUnits) - 2 bytes.
 type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
 
-// Type implements ndpOption.Type.
+// Type implements NDPOption.Type.
 func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
 	return NDPTargetLinkLayerAddressOptionType
 }
 
-// Length implements ndpOption.Length.
+// Length implements NDPOption.Length.
 func (o NDPTargetLinkLayerAddressOption) Length() int {
 	return len(o)
 }
 
-// serializeInto implements ndpOption.serializeInto.
+// serializeInto implements NDPOption.serializeInto.
 func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
 	return copy(b, o)
 }
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
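The accessor documented above pairs naturally with the iterator: a caller walks the options buffer, type-switches on each NDPOption, and pulls fields from the concrete type. A minimal caller-side sketch, not part of this change (the helper name and error handling are illustrative; it assumes the header and tcpip packages shown earlier):

// targetMAC is a hypothetical helper: extract the target's MAC address
// from the options of a Neighbor Advertisement body.
func targetMAC(opts header.NDPOptions) (tcpip.LinkAddress, bool, error) {
	it, err := opts.Iter(true) // integrity-check the whole buffer up front
	if err != nil {
		return "", false, err
	}
	for {
		opt, done, err := it.Next()
		if err != nil || done {
			// err is always nil here after a successful Iter(true).
			return "", false, err
		}
		if tll, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok {
			return tll.EthernetAddress(), true, nil
		}
	}
}

Note that Iter(true) runs its validation pass on a copy of the iterator value (for it2 := it), so the iterator it hands back still points at the first option.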
+func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+	if len(o) >= EthernetAddressSize {
+		return tcpip.LinkAddress(o[:EthernetAddressSize])
+	}
+
+	return tcpip.LinkAddress([]byte(nil))
+}
+
+// NDPPrefixInformation is the NDP Prefix Information option as defined by
+// RFC 4861 section 4.6.2.
+//
+// The length, in bytes, of a valid NDP Prefix Information option body MUST be
+// ndpPrefixInformationLength bytes.
+type NDPPrefixInformation []byte
+
+// Type implements NDPOption.Type.
+func (o NDPPrefixInformation) Type() uint8 {
+	return NDPPrefixInformationType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPPrefixInformation) Length() int {
+	return ndpPrefixInformationLength
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPPrefixInformation) serializeInto(b []byte) int {
+	used := copy(b, o)
+
+	// Zero out the Reserved1 field.
+	b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask
+
+	// Zero out the Reserved2 field.
+	reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length]
+	for i := range reserved2 {
+		reserved2[i] = 0
+	}
+
+	return used
+}
+
+// PrefixLength returns the number of leading bits in the Prefix that are
+// valid.
+//
+// Valid values are in the range [0, 128], but o may not always contain valid
+// values. It is up to the caller to validate the Prefix Information option.
+func (o NDPPrefixInformation) PrefixLength() uint8 {
+	return o[ndpPrefixInformationPrefixLengthOffset]
+}
+
+// OnLinkFlag returns true if the prefix is considered on-link. On-link means
+// that a forwarding node is not needed to send packets to other nodes on the
+// same prefix.
+//
+// Note, when this function returns false, no statement is made about the
+// on-link property of a prefix. That is, if OnLinkFlag returns false, the
+// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any
+// previously stored state for this prefix about its on-link status.
+func (o NDPPrefixInformation) OnLinkFlag() bool {
+	return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0
+}
+
+// AutonomousAddressConfigurationFlag returns true if the prefix can be used for
+// Stateless Address Auto-Configuration (as specified in RFC 4862).
+func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool {
+	return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0
+}
+
+// ValidLifetime returns the length of time that the prefix is valid for the
+// purpose of on-link determination. This value is relative to the send time of
+// the packet that the Prefix Information option was present in.
+//
+// Note, a value of 0 implies the prefix should not be considered as on-link,
+// and a value of infinity/forever is represented by
+// NDPInfiniteLifetime.
+func (o NDPPrefixInformation) ValidLifetime() time.Duration {
+	// The field is the time in seconds, as per RFC 4861 section 4.6.2.
+	return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:]))
+}
+
+// PreferredLifetime returns the length of time that an address generated from
+// the prefix via Stateless Address Auto-Configuration remains preferred. This
+// value is relative to the send time of the packet that the Prefix Information
+// option was present in.
+// +// Note, a value of 0 implies that addresses generated from the prefix should +// no longer remain preferred, and a value of infinity is represented by +// NDPInfiniteLifetime. +// +// Also note that the value of this field MUST NOT exceed the Valid Lifetime +// field to avoid preferring addresses that are no longer valid, for the +// purpose of Stateless Address Auto-Configuration. +func (o NDPPrefixInformation) PreferredLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:])) +} + +// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix +// Length field (see NDPPrefixInformation.PrefixLength) contains the number +// of valid leading bits in the prefix. +// +// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field +// holds the link-local prefix (fe80::). +func (o NDPPrefixInformation) Prefix() tcpip.Address { + return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) +} + +// Subnet returns the Prefix field and Prefix Length field represented in a +// tcpip.Subnet. +func (o NDPPrefixInformation) Subnet() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: o.Prefix(), + PrefixLen: int(o.PrefixLength()), + } + return addrWithPrefix.Subnet() +} + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() uint8 { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, some of the addresses returned MAY be link-local addresses. +// +// Addresses may panic if o does not hold valid IPv6 addresses. 
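The modulo constraint above fixes the address count: an RDNSS option whose Length field is l spans l*8 octets, of which 8 are Type, Length, Reserved and Lifetime, leaving 16 per address. A valid l is therefore odd and at least 3, carrying (l-1)/2 addresses (l=3 holds one address, l=5 two, and so on, matching the tests later in this change). A hedged sketch of that arithmetic (the helper name is hypothetical, not part of this change):

// rdnssAddressCount is an illustrative helper: the number of IPv6
// addresses carried by an RDNSS option whose Length field is l, per
// RFC 8106 section 5.1 (valid l is odd and >= 3).
func rdnssAddressCount(l uint8) (int, bool) {
	if l < 3 || l%2 == 0 {
		return 0, false
	}
	// l*8 total octets - 8 octets of Type/Length/Reserved/Lifetime,
	// at 16 octets per address.
	return (int(l) - 1) / 2, true
}

The Addresses accessor that follows enforces the same constraint at parse time, returning nil rather than panicking when the byte count does not divide evenly into IPv6 addresses.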
+func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { + l := len(o) + if l < ndpRecursiveDNSServerAddressesOffset { + return nil + } + + l -= ndpRecursiveDNSServerAddressesOffset + if l%IPv6AddressSize != 0 { + return nil + } + + buf := o[ndpRecursiveDNSServerAddressesOffset:] + var addrs []tcpip.Address + for len(buf) > 0 { + addr := tcpip.Address(buf[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return nil + } + addrs = append(addrs, addr) + buf = buf[IPv6AddressSize:] + } + return addrs +} diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go new file mode 100644 index 000000000..bf7610863 --- /dev/null +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -0,0 +1,112 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "time" +) + +// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.2 for more details. +type NDPRouterAdvert []byte + +const ( + // NDPRAMinimumSize is the minimum size of a valid NDP Router + // Advertisement message (body of an ICMPv6 packet). + NDPRAMinimumSize = 12 + + // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field + // within an NDPRouterAdvert. + ndpRACurrHopLimitOffset = 0 + + // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags + // within an NDPRouterAdvert. + ndpRAFlagsOffset = 1 + + // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address + // Configuration flag within the bit-field/flags byte of an + // NDPRouterAdvert. + ndpRAManagedAddrConfFlagMask = (1 << 7) + + // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag + // within the bit-field/flags byte of an NDPRouterAdvert. + ndpRAOtherConfFlagMask = (1 << 6) + + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime + // field within an NDPRouterAdvert. + ndpRARouterLifetimeOffset = 2 + + // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time + // field within an NDPRouterAdvert. + ndpRAReachableTimeOffset = 4 + + // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer + // field within an NDPRouterAdvert. + ndpRARetransTimerOffset = 8 + + // ndpRAOptionsOffset is the start of the NDP options in an + // NDPRouterAdvert. + ndpRAOptionsOffset = 12 +) + +// CurrHopLimit returns the value of the Curr Hop Limit field. +func (b NDPRouterAdvert) CurrHopLimit() uint8 { + return b[ndpRACurrHopLimitOffset] +} + +// ManagedAddrConfFlag returns the value of the Managed Address Configuration +// flag. +func (b NDPRouterAdvert) ManagedAddrConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0 +} + +// OtherConfFlag returns the value of the Other Configuration flag. +func (b NDPRouterAdvert) OtherConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 +} + +// RouterLifetime returns the lifetime associated with the default router. 
A
+// value of 0 means the source of the Router Advertisement is not a default
+// router and SHOULD NOT appear on the default router list. Note, a value of 0
+// only means that the router should not be used as a default router, it does
+// not apply to other information contained in the Router Advertisement.
+func (b NDPRouterAdvert) RouterLifetime() time.Duration {
+	// The field is the time in seconds, as per RFC 4861 section 4.2.
+	return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
+}
+
+// ReachableTime returns the time that a node assumes a neighbor is reachable
+// after having received a reachability confirmation. A value of 0 means
+// that it is unspecified by the source of the Router Advertisement message.
+func (b NDPRouterAdvert) ReachableTime() time.Duration {
+	// The field is the time in milliseconds, as per RFC 4861 section 4.2.
+	return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
+}
+
+// RetransTimer returns the time between retransmitted Neighbor Solicitation
+// messages. A value of 0 means that it is unspecified by the source of the
+// Router Advertisement message.
+func (b NDPRouterAdvert) RetransTimer() time.Duration {
+	// The field is the time in milliseconds, as per RFC 4861 section 4.2.
+	return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
+}
+
+// Options returns an NDPOptions of the options body.
+func (b NDPRouterAdvert) Options() NDPOptions {
+	return NDPOptions(b[ndpRAOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index a431a6e61..2c439d70c 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -17,7 +17,9 @@ package header
 import (
 	"bytes"
 	"testing"
+	"time"
 
+	"github.com/google/go-cmp/cmp"
 	"gvisor.dev/gvisor/pkg/tcpip"
 )
 
@@ -35,18 +37,18 @@ func TestNDPNeighborSolicit(t *testing.T) {
 	ns := NDPNeighborSolicit(b)
 	addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
 	if got := ns.TargetAddress(); got != addr {
-		t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr)
+		t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
 	}
 
 	// Test updating the Target Address.
 	addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
 	ns.SetTargetAddress(addr2)
 	if got := ns.TargetAddress(); got != addr2 {
-		t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2)
+		t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
 	}
 
 	// Make sure the address got updated in the backing buffer.
 	if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
-		t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+		t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
 	}
 }
 
@@ -64,59 +66,129 @@ func TestNDPNeighborAdvert(t *testing.T) {
 	na := NDPNeighborAdvert(b)
 	addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
 	if got := na.TargetAddress(); got != addr {
-		t.Fatalf("got TargetAddress = %s, want %s", got, addr)
+		t.Errorf("got TargetAddress = %s, want %s", got, addr)
 	}
 
 	// Test getting the Router Flag.
 	if got := na.RouterFlag(); !got {
-		t.Fatalf("got RouterFlag = false, want = true")
+		t.Errorf("got RouterFlag = false, want = true")
 	}
 
 	// Test getting the Solicited Flag.
 	if got := na.SolicitedFlag(); got {
-		t.Fatalf("got SolicitedFlag = true, want = false")
+		t.Errorf("got SolicitedFlag = true, want = false")
 	}
 
 	// Test getting the Override Flag.
 	if got := na.OverrideFlag(); !got {
-		t.Fatalf("got OverrideFlag = false, want = true")
+		t.Errorf("got OverrideFlag = false, want = true")
 	}
 
 	// Test updating the Target Address.
 	addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
 	na.SetTargetAddress(addr2)
 	if got := na.TargetAddress(); got != addr2 {
-		t.Fatalf("got TargetAddress = %s, want %s", got, addr2)
+		t.Errorf("got TargetAddress = %s, want %s", got, addr2)
 	}
 
 	// Make sure the address got updated in the backing buffer.
 	if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
-		t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+		t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
 	}
 
 	// Test updating the Router Flag.
 	na.SetRouterFlag(false)
 	if got := na.RouterFlag(); got {
-		t.Fatalf("got RouterFlag = true, want = false")
+		t.Errorf("got RouterFlag = true, want = false")
 	}
 
 	// Test updating the Solicited Flag.
 	na.SetSolicitedFlag(true)
 	if got := na.SolicitedFlag(); !got {
-		t.Fatalf("got SolicitedFlag = false, want = true")
+		t.Errorf("got SolicitedFlag = false, want = true")
 	}
 
 	// Test updating the Override Flag.
 	na.SetOverrideFlag(false)
 	if got := na.OverrideFlag(); got {
-		t.Fatalf("got OverrideFlag = true, want = false")
+		t.Errorf("got OverrideFlag = true, want = false")
 	}
 
 	// Make sure flags got updated in the backing buffer.
 	if got := b[ndpNAFlagsOffset]; got != 64 {
-		t.Fatalf("got flags byte = %d, want = 64")
+		t.Errorf("got flags byte = %d, want = 64", got)
 	}
 }
 
+func TestNDPRouterAdvert(t *testing.T) {
+	b := []byte{
+		64, 128, 1, 2,
+		3, 4, 5, 6,
+		7, 8, 9, 10,
+	}
+
+	ra := NDPRouterAdvert(b)
+
+	if got := ra.CurrHopLimit(); got != 64 {
+		t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
+	}
+
+	if got := ra.ManagedAddrConfFlag(); !got {
+		t.Errorf("got ManagedAddrConfFlag = false, want = true")
+	}
+
+	if got := ra.OtherConfFlag(); got {
+		t.Errorf("got OtherConfFlag = true, want = false")
+	}
+
+	if got, want := ra.RouterLifetime(), time.Second*258; got != want {
+		t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
+	}
+
+	if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
+		t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
+	}
+
+	if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
+		t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+	}
+}
+
+// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPTargetLinkLayerAddressOption.
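The expected values in TestNDPRouterAdvert above fall straight out of big-endian decoding of the 12-byte vector; a self-contained check using only the standard library (this throwaway program is illustrative, not part of the change):

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	b := []byte{64, 128, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	fmt.Println(b[0])                               // 64: Curr Hop Limit
	fmt.Println(b[1]&(1<<7) != 0, b[1]&(1<<6) != 0) // true false: M and O flags
	fmt.Println(binary.BigEndian.Uint16(b[2:4]))    // 258: Router Lifetime, seconds
	fmt.Println(binary.BigEndian.Uint32(b[4:8]))    // 50595078: Reachable Time, ms
	fmt.Println(binary.BigEndian.Uint32(b[8:12]))   // 117967114: Retrans Timer, ms
}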
+func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+	tests := []struct {
+		name     string
+		buf      []byte
+		expected tcpip.LinkAddress
+	}{
+		{
+			"ValidMAC",
+			[]byte{1, 2, 3, 4, 5, 6},
+			tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+		},
+		{
+			"TLLBodyTooShort",
+			[]byte{1, 2, 3, 4, 5},
+			tcpip.LinkAddress([]byte(nil)),
+		},
+		{
+			"TLLBodyLargerThanNeeded",
+			[]byte{1, 2, 3, 4, 5, 6, 7, 8},
+			tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			tll := NDPTargetLinkLayerAddressOption(test.buf)
+			if got := tll.EthernetAddress(); got != test.expected {
+				t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected)
+			}
+		})
+	}
+
+}
+
 // TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
 // NDPTargetLinkLayerAddressOption.
 func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
@@ -159,6 +231,555 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
 		if !bytes.Equal(test.buf, test.expectedBuf) {
 			t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
 		}
+
+		it, err := opts.Iter(true)
+		if err != nil {
+			t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+		}
+
+		if len(test.expectedBuf) > 0 {
+			next, done, err := it.Next()
+			if err != nil {
+				t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+			}
+			if done {
+				t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+			}
+			if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+				t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+			}
+			tll := next.(NDPTargetLinkLayerAddressOption)
+			if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+				t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+			}
+
+			if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+				t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+			}
+		}
+
+		// Iterator should not return anything else.
+		next, done, err := it.Next()
+		if err != nil {
+			t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+		}
+		if !done {
+			t.Error("got Next = (_, false, _), want = (_, true, _)")
+		}
+		if next != nil {
+			t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+		}
+		})
+	}
+}
+
+// TestNDPPrefixInformationOption tests the field getters and serialization of
+// an NDPPrefixInformation.
+func TestNDPPrefixInformationOption(t *testing.T) { + b := []byte{ + 43, 127, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 5, 5, 5, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPPrefixInformation(b), + } + opts.Serialize(serializer) + expectedBuf := []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + if !bytes.Equal(targetBuf, expectedBuf) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + pi := next.(NDPPrefixInformation) + + if got := pi.Type(); got != 3 { + t.Errorf("got Type = %d, want = 3", got) + } + + if got := pi.Length(); got != 30 { + t.Errorf("got Length = %d, want = 30", got) + } + + if got := pi.PrefixLength(); got != 43 { + t.Errorf("got PrefixLength = %d, want = 43", got) + } + + if pi.OnLinkFlag() { + t.Error("got OnLinkFlag = true, want = false") + } + + if !pi.AutonomousAddressConfigurationFlag() { + t.Error("got AutonomousAddressConfigurationFlag = false, want = true") + } + + if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { + t.Errorf("got ValidLifetime = %d, want = %d", got, want) + } + + if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { + t.Errorf("got PreferredLifetime = %d, want = %d", got, want) + } + + if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { + t.Errorf("got Prefix = %s, want = %s", got, want) + } + + // Iterator should not return anything else. 
+	next, done, err = it.Next()
+	if err != nil {
+		t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+	}
+	if !done {
+		t.Error("got Next = (_, false, _), want = (_, true, _)")
+	}
+	if next != nil {
+		t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+	}
+}
+
+func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
+	b := []byte{
+		9, 8,
+		1, 2, 4, 8,
+		0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+	}
+	targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+	expected := []byte{
+		25, 3, 0, 0,
+		1, 2, 4, 8,
+		0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+	}
+	opts := NDPOptions(targetBuf)
+	serializer := NDPOptionsSerializer{
+		NDPRecursiveDNSServer(b),
+	}
+	if got, want := opts.Serialize(serializer), len(expected); got != want {
+		t.Errorf("got Serialize = %d, want = %d", got, want)
+	}
+	if !bytes.Equal(targetBuf, expected) {
+		t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+	}
+
+	it, err := opts.Iter(true)
+	if err != nil {
+		t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+	}
+
+	next, done, err := it.Next()
+	if err != nil {
+		t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+	}
+	if done {
+		t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+	}
+	if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+		t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+	}
+
+	opt, ok := next.(NDPRecursiveDNSServer)
+	if !ok {
+		t.Fatalf("next (type = %T) cannot be cast to an NDPRecursiveDNSServer", next)
+	}
+	if got := opt.Type(); got != 25 {
+		t.Errorf("got Type = %d, want = 25", got)
+	}
+	if got := opt.Length(); got != 22 {
+		t.Errorf("got Length = %d, want = 22", got)
+	}
+	if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
+		t.Errorf("got Lifetime = %s, want = %s", got, want)
+	}
+	want := []tcpip.Address{
+		"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+	}
+	if got := opt.Addresses(); !cmp.Equal(got, want) {
+		t.Errorf("got Addresses = %v, want = %v", got, want)
+	}
+
+	// Iterator should not return anything else.
+	next, done, err = it.Next()
+	if err != nil {
+		t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+	}
+	if !done {
+		t.Error("got Next = (_, false, _), want = (_, true, _)")
+	}
+	if next != nil {
+		t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+	}
+}
+
+func TestNDPRecursiveDNSServerOption(t *testing.T) {
+	tests := []struct {
+		name     string
+		buf      []byte
+		lifetime time.Duration
+		addrs    []tcpip.Address
+	}{
+		{
+			"Valid1Addr",
+			[]byte{
+				25, 3, 0, 0,
+				0, 0, 0, 0,
+				0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+			},
+			0,
+			[]tcpip.Address{
+				"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+			},
+		},
+		{
+			"Valid2Addr",
+			[]byte{
+				25, 5, 0, 0,
+				0, 0, 0, 0,
+				0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+				17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+			},
+			0,
+			[]tcpip.Address{
+				"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+				"\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+			},
+		},
+		{
+			"Valid3Addr",
+			[]byte{
+				25, 7, 0, 0,
+				0, 0, 0, 0,
+				0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+				17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+				17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
+			},
+			0,
+			[]tcpip.Address{
+				"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+				"\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+				"\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			opts := NDPOptions(test.buf)
+			it, err := opts.Iter(true)
+			if err != nil {
+				t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+			}
+
+			// Iterator should get our option.
+			next, done, err := it.Next()
+			if err != nil {
+				t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+			}
+			if done {
+				t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+			}
+			if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+				t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+			}
+
+			opt, ok := next.(NDPRecursiveDNSServer)
+			if !ok {
+				t.Fatalf("next (type = %T) cannot be cast to an NDPRecursiveDNSServer", next)
+			}
+			if got := opt.Lifetime(); got != test.lifetime {
+				t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+			}
+			if got := opt.Addresses(); !cmp.Equal(got, test.addrs) {
+				t.Errorf("got Addresses = %v, want = %v", got, test.addrs)
+			}
+
+			// Iterator should not return anything else.
+			next, done, err = it.Next()
+			if err != nil {
+				t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+			}
+			if !done {
+				t.Error("got Next = (_, false, _), want = (_, true, _)")
+			}
+			if next != nil {
+				t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+			}
 		})
 	}
 }
+
+// TestNDPOptionsIterCheck tests that Iter returns an error if the NDPOptions
+// it is called on is malformed.
+func TestNDPOptionsIterCheck(t *testing.T) {
+	tests := []struct {
+		name     string
+		buf      []byte
+		expected error
+	}{
+		{
+			"ZeroLengthField",
+			[]byte{0, 0, 0, 0, 0, 0, 0, 0},
+			ErrNDPOptZeroLength,
+		},
+		{
+			"ValidTargetLinkLayerAddressOption",
+			[]byte{2, 1, 1, 2, 3, 4, 5, 6},
+			nil,
+		},
+		{
+			"TooSmallTargetLinkLayerAddressOption",
+			[]byte{2, 1, 1, 2, 3, 4, 5},
+			ErrNDPOptBufExhausted,
+		},
+		{
+			"ValidPrefixInformation",
+			[]byte{
+				3, 4, 43, 64,
+				1, 2, 3, 4,
+				5, 6, 7, 8,
+				0, 0, 0, 0,
+				9, 10, 11, 12,
+				13, 14, 15, 16,
+				17, 18, 19, 20,
+				21, 22, 23, 24,
+			},
+			nil,
+		},
+		{
+			"TooSmallPrefixInformation",
+			[]byte{
+				3, 4, 43, 64,
+				1, 2, 3, 4,
+				5, 6, 7, 8,
+				0, 0, 0, 0,
+				9, 10, 11, 12,
+				13, 14, 15, 16,
+				17, 18, 19, 20,
+				21, 22, 23,
+			},
+			ErrNDPOptBufExhausted,
+		},
+		{
+			"InvalidPrefixInformationLength",
+			[]byte{
+				3, 3, 43, 64,
+				1, 2, 3, 4,
+				5, 6, 7, 8,
+				0, 0, 0, 0,
+				9, 10, 11, 12,
+				13, 14, 15, 16,
+			},
+			ErrNDPOptMalformedBody,
+		},
+		{
+			"ValidTargetLinkLayerAddressWithPrefixInformation",
+			[]byte{
+				// Target Link-Layer Address.
+				2, 1, 1, 2, 3, 4, 5, 6,
+
+				// Prefix information.
+				3, 4, 43, 64,
+				1, 2, 3, 4,
+				5, 6, 7, 8,
+				0, 0, 0, 0,
+				9, 10, 11, 12,
+				13, 14, 15, 16,
+				17, 18, 19, 20,
+				21, 22, 23, 24,
+			},
+			nil,
+		},
+		{
+			"ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+			[]byte{
+				// Target Link-Layer Address.
+				2, 1, 1, 2, 3, 4, 5, 6,
+
+				// 255 is an unrecognized type. If 255 ends up
+				// being the type for some recognized type,
+				// update 255 to some other unrecognized value.
+				255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+				// Prefix information.
+				3, 4, 43, 64,
+				1, 2, 3, 4,
+				5, 6, 7, 8,
+				0, 0, 0, 0,
+				9, 10, 11, 12,
+				13, 14, 15, 16,
+				17, 18, 19, 20,
+				21, 22, 23, 24,
+			},
+			nil,
+		},
+		{
+			"InvalidRecursiveDNSServerCutsOffAddress",
+			[]byte{
+				25, 4, 0, 0,
+				0, 0, 0, 0,
+				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+				0, 1, 2, 3, 4, 5, 6, 7,
+			},
+			ErrNDPOptMalformedBody,
+		},
+		{
+			"InvalidRecursiveDNSServerInvalidLengthField",
+			[]byte{
+				25, 2, 0, 0,
+				0, 0, 0, 0,
+				0, 1, 2, 3, 4, 5, 6, 7, 8,
+			},
+			ErrNDPInvalidLength,
+		},
+		{
+			"RecursiveDNSServerTooSmall",
+			[]byte{
+				25, 1, 0, 0,
+				0, 0, 0,
+			},
+			ErrNDPOptBufExhausted,
+		},
+		{
+			"RecursiveDNSServerMulticast",
+			[]byte{
+				25, 3, 0, 0,
+				0, 0, 0, 0,
+				255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+			},
+			ErrNDPOptMalformedBody,
+		},
+		{
+			"RecursiveDNSServerUnspecified",
+			[]byte{
+				25, 3, 0, 0,
+				0, 0, 0, 0,
+				0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+			},
+			ErrNDPOptMalformedBody,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			opts := NDPOptions(test.buf)
+
+			if _, err := opts.Iter(true); err != test.expected {
+				t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected)
+			}
+
+			// test.buf may be malformed but we chose not to
+			// check, so Iter must not return an error.
+			if _, err := opts.Iter(false); err != nil {
+				t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+			}
+		})
+	}
+}
+
+// TestNDPOptionsIter tests that we can iterate over a valid NDPOptions. Note,
+// this test does not actually check any of the option's getters, it simply
+// checks the option Type and Body. We have other tests that test the option
+// field getters given an option body and don't need to duplicate those tests
+// here.
+func TestNDPOptionsIter(t *testing.T) {
+	buf := []byte{
+		// Target Link-Layer Address.
+		2, 1, 1, 2, 3, 4, 5, 6,
+
+		// 255 is an unrecognized type.
If 255 ends up being the type
+		// for some recognized type, update 255 to some other
+		// unrecognized value. Note, this option should be skipped when
+		// iterating.
+		255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+		// Prefix information.
+		3, 4, 43, 64,
+		1, 2, 3, 4,
+		5, 6, 7, 8,
+		0, 0, 0, 0,
+		9, 10, 11, 12,
+		13, 14, 15, 16,
+		17, 18, 19, 20,
+		21, 22, 23, 24,
+	}
+
+	opts := NDPOptions(buf)
+	it, err := opts.Iter(true)
+	if err != nil {
+		t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+	}
+
+	// Test the first (Target Link-Layer) option.
+	next, done, err := it.Next()
+	if err != nil {
+		t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+	}
+	if done {
+		t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+	}
+	if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+		t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+	}
+	if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+		t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+	}
+
+	// Test the next (Prefix Information) option.
+	// Note, the unrecognized option should be skipped.
+	next, done, err = it.Next()
+	if err != nil {
+		t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+	}
+	if done {
+		t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+	}
+	if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) {
+		t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+	}
+	if got := next.Type(); got != NDPPrefixInformationType {
+		t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+	}
+
+	// Iterator should not return anything else.
+	next, done, err = it.Next()
+	if err != nil {
+		t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+	}
+	if !done {
+		t.Error("got Next = (_, false, _), want = (_, true, _)")
+	}
+	if next != nil {
+		t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+	}
+}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 97a794986..7dbc05754 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -6,7 +6,7 @@ go_library(
     name = "channel",
     srcs = ["channel.go"],
     importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel",
-    visibility = ["//:sandbox"],
+    visibility = ["//visibility:public"],
     deps = [
         "//pkg/tcpip",
         "//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 18adb2085..70188551f 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -25,10 +25,9 @@ import (
 
 // PacketInfo holds all the information about an outbound packet.
 type PacketInfo struct {
-	Header  buffer.View
-	Payload buffer.View
-	Proto   tcpip.NetworkProtocolNumber
-	GSO     *stack.GSO
+	Pkt   tcpip.PacketBuffer
+	Proto tcpip.NetworkProtocolNumber
+	GSO   *stack.GSO
 }
 
 // Endpoint is link layer endpoint that stores outbound packets in a channel
@@ -65,14 +64,14 @@ func (e *Endpoint) Drain() int {
 	}
 }
 
-// Inject injects an inbound packet.
-func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
-	e.InjectLinkAddr(protocol, "", vv)
+// InjectInbound injects an inbound packet.
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+	e.InjectLinkAddr(protocol, "", pkt)
 }
 
 // InjectLinkAddr injects an inbound packet with a remote link address.
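The Inject-to-InjectInbound rename across these link endpoints tracks the switch from raw VectorisedViews to tcpip.PacketBuffer. A hedged usage sketch (the helper is hypothetical; it assumes the endpoint was already attached to a stack, e.g. via CreateNIC, and that raw holds a complete IPv6 packet):

// injectIPv6 is an illustrative helper, e.g. for a test harness: deliver
// a received IPv6 packet to the stack through a channel endpoint.
func injectIPv6(ep *channel.Endpoint, raw []byte) {
	// InjectInbound hands the packet straight to the dispatcher that
	// the stack registered when the endpoint was attached.
	ep.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
		Data: buffer.View(raw).ToVectorisedView(),
	})
}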
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil)) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -96,7 +95,7 @@ func (e *Endpoint) MTU() uint32 { func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { caps := stack.LinkEndpointCapabilities(0) if e.GSO { - caps |= stack.CapabilityGSO + caps |= stack.CapabilityHardwareGSO } return caps } @@ -118,12 +117,55 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, + Pkt: pkt, + Proto: protocol, + GSO: gso, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + payloadView := pkts[0].Data.ToView() + n := 0 +packetLoop: + for _, pkt := range pkts { + off := pkt.DataOffset + size := pkt.DataSize + p := PacketInfo{ + Pkt: tcpip.PacketBuffer{ + Header: pkt.Header, + Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), + }, + Proto: protocol, + GSO: gso, + } + + select { + case e.C <- p: + n++ + default: + break packetLoop + } + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := PacketInfo{ + Pkt: tcpip.PacketBuffer{Data: vv}, + Proto: 0, + GSO: nil, } select { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 8fa9e3984..897c94821 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,9 +14,7 @@ go_library( "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f80ac3435..fa8a703d9 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -165,6 +165,9 @@ type Options struct { // disabled. GSOMaxSize uint32 + // SoftwareGSOEnabled indicates whether software GSO is enabled or not. + SoftwareGSOEnabled bool + // PacketDispatchMode specifies the type of inbound dispatcher to be // used for this endpoint. 
 	PacketDispatchMode PacketDispatchMode
@@ -242,7 +245,11 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
 	}
 	if isSocket {
 		if opts.GSOMaxSize != 0 {
-			e.caps |= stack.CapabilityGSO
+			if opts.SoftwareGSOEnabled {
+				e.caps |= stack.CapabilitySoftwareGSO
+			} else {
+				e.caps |= stack.CapabilityHardwareGSO
+			}
 			e.gsoMaxSize = opts.GSOMaxSize
 		}
 	}
@@ -379,10 +386,11 @@ const (
 
 // WritePacket writes outbound packets to the file descriptor. If it is not
 // currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
 	if e.hdrSize > 0 {
 		// Add ethernet header if needed.
-		eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+		eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+		pkt.LinkHeader = buffer.View(eth)
 		ethHdr := &header.EthernetFields{
 			DstAddr: r.RemoteLinkAddress,
 			Type:    protocol,
@@ -397,17 +405,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
 		eth.Encode(ethHdr)
 	}
 
-	if e.Capabilities()&stack.CapabilityGSO != 0 {
+	if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
 		vnetHdr := virtioNetHdr{}
 		vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
 		if gso != nil {
-			vnetHdr.hdrLen = uint16(hdr.UsedLength())
+			vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
 			if gso.NeedsCsum {
 				vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
 				vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
 				vnetHdr.csumOffset = gso.CsumOffset
 			}
-			if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS {
+			if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
 				switch gso.Type {
 				case stack.GSOTCPv4:
 					vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
@@ -420,18 +428,151 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
 			}
 		}
 
-		return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView())
+		return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+	}
+
+	if pkt.Data.Size() == 0 {
+		return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View())
 	}
 
-	if payload.Size() == 0 {
-		return rawfile.NonBlockingWrite(e.fds[0], hdr.View())
+	return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil)
+}
+
+// WritePackets writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+	var ethHdrBuf []byte
+	// hdr + data
+	iovLen := 2
+	if e.hdrSize > 0 {
+		// Add ethernet header if needed.
+		ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+		eth := header.Ethernet(ethHdrBuf)
+		ethHdr := &header.EthernetFields{
+			DstAddr: r.RemoteLinkAddress,
+			Type:    protocol,
+		}
+
+		// Preserve the src address if it's set in the route.
+		if r.LocalLinkAddress != "" {
+			ethHdr.SrcAddr = r.LocalLinkAddress
+		} else {
+			ethHdr.SrcAddr = e.addr
+		}
+		eth.Encode(ethHdr)
+		iovLen++
 	}
 
-	return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
+	n := len(pkts)
+
+	views := pkts[0].Data.Views()
+	/*
+	 * Each boundary in views can add one more iovec.
+	 *
+	 * payload | | | |
+	 * -----------------------------
+	 * packets | | | | | | |
+	 * -----------------------------
+	 * iovecs  | | | | | | | | |
+	 */
+	iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
+	mmsgHdrs := make([]rawfile.MMsgHdr, n)
+
+	iovecIdx := 0
+	viewIdx := 0
+	viewOff := 0
+	off := 0
+	nextOff := 0
+	for i := range pkts {
+		// TODO(b/134618279): Different packets may have different data
+		// in the future. We should handle this.
+		if !viewsEqual(pkts[i].Data.Views(), views) {
+			panic("All packets in pkts should have the same Data.")
+		}
+
+		prevIovecIdx := iovecIdx
+		mmsgHdr := &mmsgHdrs[i]
+		mmsgHdr.Msg.Iov = &iovec[iovecIdx]
+		packetSize := pkts[i].DataSize
+		hdr := &pkts[i].Header
+
+		off = pkts[i].DataOffset
+		if off != nextOff {
+			// We stopped at a different point last time.
+			size := packetSize
+			viewIdx = 0
+			viewOff = 0
+			for size > 0 {
+				if size >= len(views[viewIdx]) {
+					size -= len(views[viewIdx])
+					viewIdx++
+					viewOff = 0
+				} else {
+					viewOff = size
+					size = 0
+				}
+			}
+		}
+		nextOff = off + packetSize
+
+		if ethHdrBuf != nil {
+			v := &iovec[iovecIdx]
+			v.Base = &ethHdrBuf[0]
+			v.Len = uint64(len(ethHdrBuf))
+			iovecIdx++
+		}
+
+		v := &iovec[iovecIdx]
+		hdrView := hdr.View()
+		v.Base = &hdrView[0]
+		v.Len = uint64(len(hdrView))
+		iovecIdx++
+
+		for packetSize > 0 {
+			vec := &iovec[iovecIdx]
+			iovecIdx++
+
+			v := views[viewIdx]
+			vec.Base = &v[viewOff]
+			s := len(v) - viewOff
+			if s <= packetSize {
+				viewIdx++
+				viewOff = 0
+			} else {
+				s = packetSize
+				viewOff += s
+			}
+			vec.Len = uint64(s)
+			packetSize -= s
+		}
+
+		mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+	}
+
+	packets := 0
+	for packets < n {
+		sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+		if err != nil {
+			return packets, err
+		}
+		packets += sent
+		mmsgHdrs = mmsgHdrs[sent:]
+	}
+	return packets, nil
+}
+
+// viewsEqual tests whether v1 and v2 refer to the same backing bytes.
+func viewsEqual(vs1, vs2 []buffer.View) bool {
+	return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+	return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
 }
 
-// WriteRawPacket writes a raw packet directly to the file descriptor.
-func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+// InjectOutbound implements stack.InjectableEndpoint.InjectOutbound.
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
 	return rawfile.NonBlockingWrite(e.fds[0], packet)
 }
 
@@ -468,9 +609,9 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
 	e.dispatcher = dispatcher
 }
 
-// Inject injects an inbound packet.
-func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
-	e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound injects an inbound packet.
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+	e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt)
 }
 
 // NewInjectable creates a new fd-based InjectableEndpoint.
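The iovec sizing in WritePackets above deserves a note: every packet contributes a header iovec (plus one for the shared ethernet header when present) and at least one data slice, and each boundary between the payload's k views can add at most one extra entry across the whole batch, hence the n*iovLen+len(views)-1 allocation. A small sketch of that bound (hypothetical helper, mirroring the allocation, not part of this change):

// maxIovecs mirrors the iovec allocation in WritePackets: n packets, a
// shared payload split into k views, and an optional per-packet
// ethernet header.
func maxIovecs(n, k int, hasEthHdr bool) int {
	perPacket := 2 // packet header + first data slice
	if hasEthHdr {
		perPacket++ // the shared ethernet header is repeated per packet
	}
	return n*perPacket + k - 1
}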
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 04406bc9a..2066987eb 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents tcpip.PacketBuffer } type context struct { @@ -92,8 +92,8 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - c.ch <- packetInfo{remote, protocol, vv.ToView()} +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} } func TestNoEthernetProperties(t *testing.T) { @@ -168,7 +168,10 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -258,7 +261,10 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.VectorisedView{}, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -293,11 +299,12 @@ func TestDeliverPacket(t *testing.T) { b[i] = uint8(rand.Intn(256)) } + var hdr header.Ethernet if !eth { // So that it looks like an IPv4 packet. b[0] = 0x40 } else { - hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr = make(header.Ethernet, header.EthernetMinimumSize) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, @@ -315,14 +322,21 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: b, + raddr: raddr, + proto: proto, + contents: tcpip.PacketBuffer{ + Data: buffer.View(b).ToVectorisedView(), + LinkHeader: buffer.View(hdr), + }, } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents.Data will be a single + // view, so make pi do the same for the + // DeepEqual check. 
+ pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 8bfeb97e4..62ed1e569 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) + eth = header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -189,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + Data: buffer.View(pkt).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 7ca217e5b..c67d684ce 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} d.views = make([]buffer.View, len(BufConfig)) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { iovLen++ } d.iovecs = make([]syscall.Iovec, iovLen) @@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { func (d *readVDispatcher) allocateViews(bufConfig []int) { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { if err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // Skip virtioNetHdr which is added before each packet, it // isn't used and it isn't in a view. n -= virtioNetHdrSize @@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[0]) + eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) - vv.TrimFront(d.e.hdrSize) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. 
for i := 0; i < used; i++ { @@ -194,7 +198,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { } d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // virtioNetHdr is prepended before each packet. iovLen++ } @@ -225,7 +229,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { for k := 0; k < len(d.views); k++ { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -261,7 +265,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } if n <= d.e.hdrSize { @@ -271,9 +275,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -291,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index 47a54845c..f35fcdff4 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -6,10 +6,11 @@ go_library( name = "loopback", srcs = ["loopback.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index b36629d2c..499cc608f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -23,6 +23,7 @@ package loopback import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,21 +71,45 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. 
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) - // Because we're immediately turning around and writing the packet back to the - // rx path, we intentionally don't preserve the remote and local link - // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) return nil } -// Wait implements stack.LinkEndpoint.Wait. -func (*endpoint) Wait() {} +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { + return tcpip.ErrBadAddress + } + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + vv.TrimFront(len(linkHeader)) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + Data: vv, + LinkHeader: buffer.View(linkHeader), + }) + + return nil +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1bab380b0..1ac7948b6 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -7,9 +7,7 @@ go_library( name = "muxed", srcs = ["injectable.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 7c946101d..445b22c17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -79,29 +79,47 @@ func (m *InjectableEndpoint) IsAttached() bool { return m.dispatcher != nil } -// Inject implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound implements stack.InjectableLinkEndpoint. 
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +} + +// WritePackets writes outbound packets to the appropriate +// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if +// r.RemoteAddress has a route registered in this endpoint. +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + endpoint, ok := m.routes[r.RemoteAddress] + if !ok { + return 0, tcpip.ErrNoRoute + } + return endpoint.WritePackets(r, gso, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, protocol, pkt) } return tcpip.ErrNoRoute } -// WriteRawPacket writes outbound packets to the appropriate +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // WriteRawPacket doesn't get a route or network address, so there's + // nowhere to write this. + return tcpip.ErrNoRoute +} + +// InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { endpoint, ok := m.routes[dest] if !ok { return tcpip.ErrNoRoute } - return endpoint.WriteRawPacket(dest, packet) + return endpoint.InjectOutbound(dest, packet) } // Wait implements stack.LinkEndpoint.Wait. 
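// Illustrative sketch, not part of this commit: how a caller drives the
// renamed muxed API above. InjectInbound pushes a received packet up into
// the attached stack; InjectOutbound replaces the old WriteRawPacket(dest, b)
// for raw outbound bytes; WritePacket now bundles header and payload in a
// tcpip.PacketBuffer. Assumed rather than shown in this diff: the
// NewInjectableEndpoint constructor shape, and child/dstIP/payload being
// supplied by the caller.
package muxedexample

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/link/muxed"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func demo(child stack.InjectableLinkEndpoint, dstIP tcpip.Address, payload []byte) *tcpip.Error {
	// Assumed constructor shape: one child endpoint per remote address.
	ep := muxed.NewInjectableEndpoint(map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: child})

	// Outbound data path: header and payload now travel in one PacketBuffer.
	hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
	r := stack.Route{RemoteAddress: dstIP}
	if err := ep.WritePacket(&r, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
		Header: hdr,
		Data:   buffer.NewViewFromBytes(payload).ToVectorisedView(),
	}); err != nil {
		return err // tcpip.ErrNoRoute if dstIP was never registered.
	}

	// Raw outbound bytes bypass the header machinery entirely.
	return ep.InjectOutbound(dstIP, payload)
}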
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3086fec00..63b249837 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -31,7 +31,7 @@ import ( func TestInjectableEndpointRawDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - endpoint.WriteRawPacket(dstIP, []byte{0xFA}) + endpoint.InjectOutbound(dstIP, []byte{0xFA}) buf := make([]byte, ipv4.MaxTotalSize) bytesRead, err := sock.Read(buf) @@ -50,8 +50,10 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -68,8 +70,10 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewView(0).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 2e8bc772a..d8211e93d 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -13,8 +13,9 @@ go_library( "rawfile_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = [ - "//visibility:public", + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@org_golang_x_sys//unix:go_default_library", ], - deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index dda3b10a6..0b5a6cf49 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 7e286a3a6..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -22,6 +22,7 @@ import ( "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { return nil } +// NonBlockingSendMMsg sends multiple messages on a socket. +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { + n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) + if e != 0 { + return 0, TranslateErrno(e) + } + + return int(n), nil +} + // PollEvent represents the pollfd structure passed to a poll() system call. 
type PollEvent struct { FD int32 diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 0a5ea3dc4..a4f9cdd69 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -12,9 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = [ - "//:sandbox", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 330ed5e94..6b5bc542c 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -12,7 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index de1ce043d..8c9234d54 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -10,7 +10,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip/link/sharedmem/pipe", diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 9e71d4edf..080f9d667 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,9 +185,10 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -199,10 +200,30 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa } eth.Encode(ethHdr) - v := payload.ToView() + v := pkt.Data.ToView() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) + ok := e.tx.transmit(pkt.Header.View(), v) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + v := vv.ToView() + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(v, buffer.View{}) e.mu.Unlock() if !ok { @@ -253,8 +274,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { } // Send packet up the stack. 
- eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView()) + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 0e9ba0846..89603c48f 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() { } type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView + addr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + vv buffer.VectorisedView + linkHeader buffer.View } type testContext struct { @@ -130,12 +131,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, proto: proto, - vv: vv.Clone(nil), + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() @@ -272,7 +273,10 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -341,7 +345,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -395,7 +401,10 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -410,7 +419,10 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. 
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -435,7 +447,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -455,7 +470,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -470,7 +488,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -493,7 +514,10 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -509,7 +533,10 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -534,7 +561,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. 
for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -547,7 +577,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: uu, + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -555,7 +588,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 1756114e6..d6ae0368a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -9,9 +9,7 @@ go_library( "sniffer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index e401dce44..3392b7edd 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -49,6 +49,13 @@ var LogPackets uint32 = 1 // LogPacketsToFile must be accessed atomically. var LogPacketsToFile uint32 = 1 +var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{ + header.ICMPv4ProtocolNumber: header.IPv4MinimumSize, + header.ICMPv6ProtocolNumber: header.IPv6MinimumSize, + header.UDPProtocolNumber: header.UDPMinimumSize, + header.TCPProtocolNumber: header.TCPMinimumSize, +} + type endpoint struct { dispatcher stack.NetworkDispatcher lower stack.LinkEndpoint @@ -116,19 +123,19 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. 
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, pkt.Data.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() + vs := pkt.Data.Views() + length := pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { panic(err) } for _, v := range vs { @@ -147,7 +154,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local panic(err) } } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher @@ -193,22 +200,19 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket implements the stack.LinkEndpoint interface. It is called by -// higher-level protocols to write packets; it just logs the packet and forwards -// the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View(), gso) + logPacket("send", protocol, pkt.Header.View(), gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := hdr.View() - length := len(hdrBuf) + payload.Size() + hdrBuf := pkt.Header.View() + length := len(hdrBuf) + pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil { panic(err) } if len(hdrBuf) > length { @@ -218,26 +222,75 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen panic(err) } length -= len(hdrBuf) - if length > 0 { - for _, v := range payload.Views() { - if len(v) > length { - v = v[:length] - } - n, err := buf.Write(v) - if err != nil { - panic(err) - } - length -= n - if length == 0 { - break - } - } + logVectorisedView(pkt.Data, length, buf) + if _, err := e.file.Write(buf.Bytes()); err != nil { + panic(err) } + } +} + +// WritePacket implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. 
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.dumpPacket(gso, protocol, pkt) + return e.lower.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + view := pkts[0].Data.ToView() + for _, pkt := range pkts { + e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), + }) + } + return e.lower.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { + logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + length := vv.Size() + if length > int(e.maxPCAPLen) { + length = int(e.maxPCAPLen) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + panic(err) + } + logVectorisedView(vv, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } } - return e.lower.WritePacket(r, gso, hdr, payload, protocol) + return e.lower.WriteRawPacket(vv) +} + +func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { + if length <= 0 { + return + } + for _, v := range vv.Views() { + if len(v) > length { + v = v[:length] + } + n, err := buf.Write(v) + if err != nil { + panic(err) + } + length -= n + if length == 0 { + return + } + } } // Wait implements stack.LinkEndpoint.Wait. @@ -287,6 +340,13 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie return } + // We aren't guaranteed to have a transport header - it's possible for + // writes via raw endpoints to contain only network headers. + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) + return + } + // Figure out the transport layer info. 
transName := "unknown" srcPort := uint16(0) diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 92dce8fac..a71a493fc 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -6,7 +6,5 @@ go_library( name = "tun", srcs = ["tun_unsafe.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0746dc8ec..134837943 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -9,9 +9,7 @@ go_library( "waitable.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/gate", "//pkg/tcpip", diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 5a1791cb5..a8de38979 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -99,12 +99,36 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, protocol, pkt) + e.writeGate.Leave() + return err +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if !e.writeGate.Enter() { + return len(pkts), nil + } + + n, err := e.lower.WritePackets(r, gso, pkts, protocol) + e.writeGate.Leave() + return n, err +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. 
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WriteRawPacket(vv) e.writeGate.Leave() return err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index ae23c96b7..31b11a27a 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,18 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.writeCount++ + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += len(pkts) + return len(pkts), nil +} + +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { e.writeCount++ return nil } @@ -78,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -109,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. 
It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index df0d3a8c0..e7617229b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -7,9 +7,7 @@ go_library( name = "arp", srcs = ["arp.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6b1e854dc..da8482509 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,7 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -58,7 +58,7 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -79,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +// WritePackets implements stack.NetworkEndpoint.WritePackets. 
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return @@ -97,23 +102,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -130,12 +137,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ - nicid: nicid, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, }, nil @@ -160,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. 
copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 88b57ec03..8e6048a21 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -102,19 +102,21 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi := <-c.linkEP.C + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.Header.View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 2cad0a0b6..acf1e022c 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -25,7 +25,7 @@ go_library( "reassembler_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", @@ -44,11 +44,3 @@ go_test( embed = [":fragmentation"], deps = ["//pkg/tcpip/buffer"], ) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f644a8b08..4144a7837 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. 
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -150,27 +150,36 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. 
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + return tcpip.ErrNotSupported +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, @@ -230,7 +239,10 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -270,7 +282,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -358,7 +372,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -421,13 +437,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. 
- ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -460,7 +480,10 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -500,7 +523,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -510,6 +535,7 @@ func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" cases := []struct { name string expectedCount int @@ -561,7 +587,7 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + SrcAddr: outerSrcAddr, DstAddr: localIpv6Addr, }) @@ -608,8 +634,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 58e537aad..aeddfcdd4 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,9 +10,7 @@ go_library( "ipv4.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 50b363dc4..32bf39e43 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. 
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. 
h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -95,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } @@ -104,19 +112,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 5cd895ff0..e645cf62c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -47,7 +47,7 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -57,9 +57,9 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, @@ -89,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. 
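// Illustrative sketch, not part of this commit: the arithmetic behind the
// fragmentation hunk that follows. Every fragment re-carries the IPv4
// header, so the per-fragment payload capacity (innerMTU in the code below)
// is the link MTU minus the header length, and the number of fragments is a
// ceiling division over the payload size.
func fragmentCount(mtu, ipHeaderLen, payloadLen int) int {
	innerMTU := mtu - ipHeaderLen                 // payload bytes per fragment
	return (payloadLen + innerMTU - 1) / innerMTU // ceiling division
}

// Example: fragmentCount(1500, 20, 4000) == 3, i.e. two 1480-byte fragments
// plus a final 1040-byte one. Note that 1480 is a multiple of 8, which IPv4
// requires of every non-final fragment because fragment offsets are carried
// on the wire in 8-byte units.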
@@ -117,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -137,71 +138,85 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. 
+ if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payload.Size()) + length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be @@ -219,41 +234,71 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) 
loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() return nil } +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("multiple packets in local loop") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + ip := header.IPv4(pkt.Data.First()) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -267,7 +312,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. 
id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) @@ -280,35 +325,39 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, pkt.Clone()) } if loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { - if vv.Size() == 0 { + if pkt.Data.Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -316,10 +365,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and vv.size() causes a wrap - // around resulting in last being less than the offset. + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. if last < h.FragmentOffset() { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -327,7 +376,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } var ready bool var err error - vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -340,11 +389,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. 
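Two arithmetic details in the ipv4.go changes above are easy to miss: writePacketFragments budgets payload per fragment after reserving the IP header, keeping offsets on the 8-byte boundaries the fragment-offset field requires, and the reassembly path in HandlePacket rejects fragments whose offset plus size wraps a uint16. A standalone sketch of both calculations, assuming the usual RFC 791 rules; the real code operates on header.IPv4 and the buffer types:

    package main

    import "fmt"

    // fragmentCounts mirrors writePacketFragments: innerMTU is the payload
    // budget per fragment once the IP header is reserved, rounded down to a
    // multiple of 8 because fragment offsets are expressed in 8-byte units;
    // n is the number of fragments needed for the payload.
    func fragmentCounts(mtu, ipHeaderLen, payloadLen int) (innerMTU, n int) {
        innerMTU = (mtu - ipHeaderLen) &^ 7
        n = (payloadLen + innerMTU - 1) / innerMTU
        return innerMTU, n
    }

    // validFragmentRange mirrors the guard in HandlePacket: last is computed
    // in uint16 arithmetic, so an offset+size that overflows shows up as
    // last < offset, and the fragment is dropped as malformed.
    func validFragmentRange(offset uint16, size int) bool {
        if size == 0 {
            return false // a fragment with no payload is also dropped
        }
        last := offset + uint16(size) - 1
        return last >= offset
    }

    func main() {
        innerMTU, n := fragmentCounts(1500, 20, 4000)
        fmt.Println(innerMTU, n)                    // 1480 3
        fmt.Println(validFragmentRange(65528, 100)) // false: wraps past 65535
    }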
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 85ab0e3bc..e900f1b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -47,10 +47,6 @@ func TestExcludeBroadcast(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: 1, @@ -117,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -136,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -177,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan tcpip.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -187,17 +183,11 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan packetInfo, size), + Ch: make(chan tcpip.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 @@ -212,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -296,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := tcpip.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. 
- Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []tcpip.PacketBuffer L: for { select { @@ -349,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -455,7 +446,7 @@ func TestInvalidFragments(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - const nicid tcpip.NICID = 42 + const nicID tcpip.NICID = 42 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ ipv4.NewProtocol(), @@ -465,10 +456,12 @@ func TestInvalidFragments(t *testing.T) { var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicid, sniffer.New(ep)) + s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f06622a8b..e4e273460 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,9 +10,7 @@ go_library( "ipv6.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7b638e9d0..1c3410618 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -21,21 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -const ( - // ndpHopLimit is the expected IP hop limit value of 255 for received - // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, - // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently - // drop the NDP packet. All outgoing NDP packets must use this value for - // its IP hop limit field. 
- ndpHopLimit = 255 -) - // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -61,19 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return @@ -81,16 +72,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V h := header.ICMPv6(v) iph := header.IPv6(netHeader) + // Validate ICMPv6 checksum before processing the packet. + // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. + // This copy is used as extra payload during the checksum calculation. + payload := pkt.Data + payload.RemoveFirst() + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. 
switch h.Type() { case header.ICMPv6NeighborSolicit, header.ICMPv6NeighborAdvert, header.ICMPv6RouterSolicit, header.ICMPv6RouterAdvert, header.ICMPv6RedirectMsg: - if iph.HopLimit() != ndpHopLimit { + if iph.HopLimit() != header.NDPHopLimit { + received.Invalid.Increment() + return + } + + if h.Code() != 0 { received.Invalid.Increment() return } @@ -104,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -114,10 +123,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: @@ -171,7 +180,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // rxNICID so the packet is processed as defined in RFC 4861, // as per RFC 4862 section 5.4.3. - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } @@ -180,9 +189,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), } hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) @@ -200,7 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // TODO(tamird/ghanan): there exists an explicit NDP option that is // used to update the neighbor table with link addresses for a @@ -209,7 +218,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // // Furthermore, the entirety of NDP handling here seems to be // contradicted by RFC 4861. 
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // @@ -217,7 +226,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ndpHopLimit, TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return } @@ -255,19 +266,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V return } - // At this point we know that the targetAddress is not tentaive + // At this point we know that the targetAddress is not tentative // on rxNICID. However, targetAddr may still be assigned to // rxNICID but not tentative (it could be permanent). Such a // scenario is beyond the scope of RFC 4862. As such, we simply // ignore such a scenario for now and proceed as normal. // - // TODO(b/140896005): Handle the scenario described above - // (inform the netstack integration that a duplicate address was - // was detected) + // TODO(b/143147598): Handle the scenario described above. Also + // inform the netstack integration that a duplicate address was + // detected outside of DAD. - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: @@ -276,13 +287,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { sent.Dropped.Increment() return } @@ -294,7 +308,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case 
header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -306,8 +320,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.RouterSolicit.Increment() case header.ICMPv6RouterAdvert: + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + p := h.NDPPayload() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if len(p) < header.NDPRAMinimumSize { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + ra := header.NDPRouterAdvert(p) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + received.RouterAdvert.Increment() + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() @@ -359,13 +416,15 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: ndpHopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. 
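The checksum validation added to handleICMP above calls header.ICMPv6Checksum over the ICMPv6 header plus every payload view after the first (hence the shallow copy and RemoveFirst). For orientation, here is a self-contained sketch of the RFC 1071 ones-complement fold that such a checksum boils down to; the real helper also folds in the IPv6 pseudo-header (source and destination addresses, upper-layer length, next header), omitted here for brevity:

    package main

    import "fmt"

    // internetChecksum folds a byte stream into the 16-bit Internet
    // checksum: sum big-endian 16-bit words, wrap the carries back into
    // the low bits, then complement. A received packet verifies when the
    // value recomputed this way matches the stored checksum field.
    func internetChecksum(b []byte) uint16 {
        var sum uint32
        for i := 0; i+1 < len(b); i += 2 {
            sum += uint32(b[i])<<8 | uint32(b[i+1])
        }
        if len(b)%2 != 0 {
            sum += uint32(b[len(b)-1]) << 8 // odd trailing byte, zero-padded
        }
        for sum > 0xffff {
            sum = sum&0xffff + sum>>16
        }
        return ^uint16(sum)
    }

    func main() {
        // Checksum field zeroed, as the sender does before computing it.
        pkt := []byte{0x80, 0x00, 0x00, 0x00, 0xde, 0xad}
        fmt.Printf("%#04x\n", internetChecksum(pkt))
    }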
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index dd3c4d7c4..335f634d5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -30,7 +30,7 @@ import ( ) const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) @@ -55,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { return nil } @@ -65,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -131,7 +131,7 @@ func TestICMPCounts(t *testing.T) { {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, @@ -143,11 +143,13 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: ndpHopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -274,20 +276,22 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi := <-args.src.C { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -359,3 +363,537 @@ func TestLinkResolution(t *testing.T) { routeICMPv6Packet(t, args, nil) } } + +func TestICMPChecksumValidationSimple(t *testing.T) { + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + 
"DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + "RouterSolicit", + header.ICMPv6RouterSolicit, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + "RouterAdvert", + header.ICMPv6RouterAdvert, + header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + "NeighborSolicit", + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborSolicitMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + "NeighborAdvert", + header.ICMPv6NeighborAdvert, + header.ICMPv6NeighborAdvertSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + "RedirectMsg", + header.ICMPv6RedirectMsg, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // 
Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if 
err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. 
+ if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + 
e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index cd1e34085..e13f1fabf 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -43,7 +43,7 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv6 endpoint ID. @@ -97,9 +97,8 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - length := uint16(hdr.UsedLength() + payload.Size()) +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { + length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, @@ -109,14 +108,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) 
loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -124,36 +135,58 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) +} + +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("not implemented") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + hdr := &pkts[i].Header + size := pkts[i].DataSize + ip := e.addIPHeader(r, hdr, size, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { - // TODO(b/119580726): Support IPv6 header-included packets. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { - // TODO(b/146666412): Support IPv6 header-included packets. + // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. @@ -188,9 +221,9 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index deaa9b7f3..1cbfa7278 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stats := s.Stats().ICMP.V6PacketsReceived @@ -111,7 +113,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := s.Stats().UDP.PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index e30791fe3..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -97,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -109,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) { {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }}, {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { @@ -150,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) { // Receive the NDP packet with an invalid hop limit // value. - handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) // Invalid count should have increased. if got := invalid.Value(); got != 1 { @@ -164,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) { } // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, ndpHopLimit, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) // Rx count of NDP packet of type typ.typ should have // increased. 
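The hunk that follows adds TestRouterAdvertValidation, exercising the RFC 4861 section 6.1.2 checks introduced in icmp.go above: link-local source, hop limit 255, ICMPv6 code 0, a payload at least the RA minimum, and options that parse. A compact model of those drop conditions, including the option walk that opts.Iter(true) implies; the constant and helpers are illustrative stand-ins under my reading of the NDP wire format (a type byte, then a length byte counted in 8-octet units, zero being invalid), not the header package API:

    package main

    import "fmt"

    // ndpRAMinimumSize is the fixed Router Advertisement payload after the
    // 4-byte ICMPv6 header: hop limit, flags, router lifetime, reachable
    // time and retransmit timer.
    const ndpRAMinimumSize = 12

    // validRA mirrors the drop conditions added to handleICMP.
    func validRA(srcIsLinkLocal bool, hopLimit, code uint8, ndpPayload []byte) bool {
        return srcIsLinkLocal && hopLimit == 255 && code == 0 &&
            len(ndpPayload) >= ndpRAMinimumSize &&
            validOptions(ndpPayload[ndpRAMinimumSize:])
    }

    // validOptions walks NDP options: a zero-length or truncated option
    // invalidates the whole packet; unrecognized option types are fine.
    func validOptions(b []byte) bool {
        for len(b) > 0 {
            if len(b) < 2 {
                return false
            }
            l := int(b[1]) * 8
            if l == 0 || l > len(b) {
                return false
            }
            b = b[l:]
        }
        return true
    }

    func main() {
        ra := make([]byte, ndpRAMinimumSize)
        bad := append(append([]byte{}, ra...), 2, 0, 0, 0, 0, 0, 0, 0)
        fmt.Println(validRA(true, 255, 0, ra))  // true
        fmt.Println(validRA(true, 255, 0, bad)) // false: zero-length option
    }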
@@ -179,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) { }) } } + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = 
%d, want = 1", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + } + }) + } +} diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go new file mode 100644 index 000000000..ab24372e7 --- /dev/null +++ b/pkg/tcpip/packet_buffer.go @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // DataOffset is used for GSO output. It is the offset into the Data + // field where the payload of this packet starts. + DataOffset int + + // DataSize is used for GSO output. It is the size of this packet's + // payload. + DataSize int + + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. +func (pk PacketBuffer) Clone() PacketBuffer { + pk.Data = pk.Data.Clone(nil) + return pk +} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go new file mode 100644 index 000000000..ad3cc24fa --- /dev/null +++ b/pkg/tcpip/packet_buffer_state.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 11efb4e44..e156b01f6 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,7 @@ go_library( name = "ports", srcs = ["ports.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 30cea8996..6c5e19e8f 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -41,6 +41,30 @@ type portDescriptor struct { port uint16 } +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precedence over MostRecent. + LoadBalanced bool +} + +func (f Flags) bits() reuseFlag { + var rf reuseFlag + if f.MostRecent { + rf |= mostRecentFlag + } + if f.LoadBalanced { + rf |= loadBalancedFlag + } + return rf +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -54,9 +78,59 @@ type PortManager struct { hint uint32 } +type reuseFlag int + +const ( + mostRecentFlag reuseFlag = 1 << iota + loadBalancedFlag + nextFlag + + flagMask = nextFlag - 1 +) + type portNode struct { - reuse bool - refs int + // refs stores the count for each possible flag combination. + refs [nextFlag]int +} + +func (p portNode) totalRefs() int { + var total int + for _, r := range p.refs { + total += r + } + return total +} + +// flagRefs returns the number of references with all specified flags. +func (p portNode) flagRefs(flags reuseFlag) int { + var total int + for i, r := range p.refs { + if reuseFlag(i)&flags == flags { + total += r + } + } + return total +} + +// allRefsHave returns true if all references have all specified flags. +func (p portNode) allRefsHave(flags reuseFlag) bool { + for i, r := range p.refs { + if reuseFlag(i)&flags != flags && r > 0 { + return false + } + } + return true +} + +// intersectionRefs returns the set of flags shared by all references. +func (p portNode) intersectionRefs() reuseFlag { + intersection := flagMask + for i, r := range p.refs { + if r > 0 { + intersection &= reuseFlag(i) + } + } + return intersection } // deviceNode is never empty.
When it has no elements, it is removed from the @@ -66,30 +140,44 @@ type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all portNodes. If binding to a specific device, check // against the unspecified device and the provided device. -func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { +// +// If either of the port reuse flags is enabled on any of the nodes, all nodes +// sharing a port must share at least one reuse flag. This matches Linux's +// behavior. +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { + flagBits := flags.bits() if bindToDevice == 0 { // Trying to bind to all devices. - if !reuse { + if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } + intersection := flagMask for _, p := range d { - if !p.reuse { - // Can't bind because the (addr,port) was previously bound without reuse. + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { + // Can't bind because the (addr,port) was + // previously bound without reuse. return false } } return true } + intersection := flagMask + if p, ok := d[0]; ok { - if !reuse || !p.reuse { + intersection = p.intersectionRefs() + if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - if !reuse || !p.reuse { + i := p.intersectionRefs() + intersection &= i + if intersection&flagBits == 0 { return false } } @@ -103,12 +191,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -117,14 +205,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } // Check that there is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(reuse, bindToDevice) { + if !d.isAvailable(flags, bindToDevice) { return false } } @@ -190,17 +278,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols.
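//
// As an illustrative sketch only (pm, networks, transProto and addr below are
// placeholder names, not values from this change), a caller emulating
// SO_REUSEPORT semantics could check availability with:
//
//	flags := ports.Flags{LoadBalanced: true}
//	if pm.IsPortAvailable(networks, transProto, addr, 80, flags, 0) {
//		// Port 80 may still be shared by load-balanced endpoints.
//	}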
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice) { return false } } @@ -212,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -227,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } + flagBits := flags.bits() // Reserve port on all network protocols. 
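	// (Illustrative note, not part of the original change: with
	// Flags{MostRecent: true, LoadBalanced: true}, flagBits is
	// mostRecentFlag|loadBalancedFlag == 3, so the loop below counts the
	// reservation in refs[3] of the corresponding portNode.)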
for _, network := range networks { @@ -250,12 +339,9 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - if n, ok := d[bindToDevice]; ok { - n.refs++ - d[bindToDevice] = n - } else { - d[bindToDevice] = portNode{reuse: reuse, refs: 1} - } + n := d[bindToDevice] + n.refs[flagBits]++ + d[bindToDevice] = n } return true @@ -263,10 +349,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() + flagBits := flags.bits() + for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -278,9 +366,9 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n.refs-- + n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == 0 { + if n.refs == [nextFlag]int{} { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 19f4833fc..d6969d050 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -33,7 +33,7 @@ type portReserveTestAction struct { port uint16 ip tcpip.Address want *tcpip.Error - reuse bool + flags Flags release bool device tcpip.NICID } @@ -50,7 +50,7 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! 
*/ {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, }, }, { @@ -61,7 +61,7 @@ func TestPortReservation(t *testing.T) { /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true}, + {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, ip: fakeIPAddress, want: nil}, @@ -71,36 +71,36 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 00, ip: fakeIPAddress, want: nil}, {port: 00, ip: fakeIPAddress, want: nil}, - {port: 00, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to ip with reuseport", actions: []portReserveTestAction{ - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, reuse: true, want: nil}, + {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "bind to inaddr any with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, - {port: 24, ip: anyIPAddress, reuse: true, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, release: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: anyIPAddress, release: true}, - {port: 24, ip: anyIPAddress, reuse: false, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, }, }, { tname: "bind twice with device fails", @@ -125,88 +125,152 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, {port: 24, ip: 
fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { - tname: "bind with reuse", + tname: "bind with reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { - tname: "binding with reuse and device", + tname: "binding with reuseport and device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, }, }, { - tname: "mixing reuse and not reuse by binding to device", + tname: "mixing reuseport and not reuseport by binding to device", actions: []portReserveTestAction{ - 
{port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 999, want: nil}, }, }, { - tname: "can't bind to 0 after mixing reuse and not reuse", + tname: "can't bind to 0 after mixing reuseport and not reuseport", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil}, }, }, { - tname: "bind twice with reuse once", + tname: "bind twice with reuseport once", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, }, }, { tname: "release an unreserved device", actions: []portReserveTestAction{ - {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil}, // The below don't exist. - {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, - {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, // Release all. 
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, - {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true}, + }, + }, { + tname: "bind with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind twice with reuseaddr once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr and reuseport twice, and then reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, + }, + }, { + tname: "bind with reuseaddr, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuseport, and then reuseaddr and reuseport", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: 
Flags{LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, }, } { @@ -216,12 +280,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index a57752a7c..d7496fde6 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_connect", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index dad8ef399..875561566 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "tun_tcp_echo", srcs = ["main.go"], + visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", "//pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 29b7d761c..b31ddba2f 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -6,7 +6,5 @@ go_library( name = "seqnum", srcs = ["seqnum.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bfc03e90b..69077669a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -31,9 +31,7 @@ go_library( "transport_demuxer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/ilist", "//pkg/rand", @@ -73,6 +71,7 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) @@ -86,11 +85,3 @@ go_test( "//pkg/tcpip", ], ) - -filegroup( - name = "autogen", - srcs = [ - "linkaddrentry_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index bed60d7b1..90664ba8a 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -38,6 +38,34 @@ const ( // Default = 1s (from RFC 4861 section 10). defaultRetransmitTimer = time.Second + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + // + // Default = true. 
+ defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements, as a host. + // + // Default = true. + defaultDiscoverDefaultRouters = true + + // defaultDiscoverOnLinkPrefixes is the default configuration for + // whether or not to discover on-link prefixes from incoming Router + // Advertisements' Prefix Information option, as a host. + // + // Default = true. + defaultDiscoverOnLinkPrefixes = true + + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor Solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -49,8 +77,117 @@ const ( // // Min = 1ms. minimumRetransmitTimer = time.Millisecond + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + // + // Max = 10. + MaxDiscoveredDefaultRouters = 10 + + // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered + // on-link prefixes. The stack should stop discovering new on-link + // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + // + // Max = 10. + MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 ) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour +) + +// NDPDispatcher is the interface integrators of netstack must implement to +// receive and handle NDP related events. +type NDPDispatcher interface { + // OnDuplicateAddressDetectionStatus will be called when the DAD process + // for an address (addr) on a NIC (with ID nicID) completes. resolved + // will be set to true if DAD completed successfully (no duplicate addr + // detected); false otherwise (addr was detected to be a duplicate on + // the link the NIC is a part of, or it was stopped for some other + // reason, such as the address being removed). If an error occurred + // during DAD, err will be set and resolved must be ignored. + // + // This function is permitted to block indefinitely without interfering + // with the stack's operation. + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered will be called when a new default router is + // discovered. Implementations must return true if the newly discovered + // router should be remembered. + // + // This function is not permitted to block indefinitely.
This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + + // OnDefaultRouterInvalidated will be called when a discovered default + // router that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + + // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + + // OnOnLinkPrefixInvalidated will be called when a discovered on-link + // prefix that was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) +} + // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { // The number of Neighbor Solicitation messages to send when doing @@ -64,6 +201,31 @@ type NDPConfigurations struct { // // Must be greater than 0.5s. RetransmitTimer time.Duration + + // HandleRAs determines whether or not Router Advertisements will be + // processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers will + // be discovered from Router Advertisements. This configuration is + // ignored if HandleRAs is false. + DiscoverDefaultRouters bool + + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes + // will be discovered from Router Advertisements' Prefix Information + // option. This configuration is ignored if HandleRAs is false. 
+ DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -72,6 +234,10 @@ func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ DupAddrDetectTransmits: defaultDupAddrDetectTransmits, RetransmitTimer: defaultRetransmitTimer, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -88,8 +254,24 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // The NIC this ndpState is for. + nic *NIC + + // configs is the per-interface NDP configurations. + configs NDPConfigurations + // The DAD state to send the next NS message, or resolve the address. dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState + + // The on-link prefixes discovered through Router Advertisements' Prefix + // Information option. + onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The addresses generated by SLAAC. + autoGenAddresses map[tcpip.Address]autoGenAddressState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -105,13 +287,84 @@ type dadState struct { done *bool } +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement (RA). +type defaultRouterState struct { + invalidationTimer *time.Timer + + // Used to inform the timer not to invalidate the default router (R) in + // a race condition (T1 is a goroutine that handles an RA from R and T2 + // is the goroutine that handles R's invalidation timer firing): + // T1: Receive a new RA from R + // T1: Obtain the NIC's lock before processing the RA + // T2: R's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends R's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates R immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate R, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// onLinkPrefixState holds data associated with an on-link prefix discovered by +// a Router Advertisement's Prefix Information option (PI) when the NDP +// configuration was configured to do so.
+type onLinkPrefixState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition (T1 is a goroutine that handles a PI for P and T2 + // is the goroutine that handles P's invalidation timer firing): + // T1: Receive a new PI for P + // T1: Obtain the NIC's lock before processing the PI + // T2: P's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends P's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates P immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate P, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// autoGenAddressState holds data associated with an address generated via +// SLAAC. +type autoGenAddressState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the SLAAC address (A) in + // a race condition (T1 is a goroutine that handles a PI for A and T2 + // is the goroutine that handles A's invalidation timer firing): + // T1: Receive a new PI for A + // T1: Obtain the NIC's lock before processing the PI + // T2: A's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends A's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates A immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate A, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool + + // Nonzero only when the address is not valid forever (invalidationTimer + // is not nil). + validUntil time.Time +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently // tentative. // -// The NIC that ndp belongs to (n) MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { return tcpip.ErrAddressFamilyNotSupported @@ -127,13 +380,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, // reference count would have been increased without doing the // work that would have been done for an address that was brand // new. See NIC.addPermanentAddressLocked. 
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, n.ID())) + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } - remaining := n.stack.ndpConfigs.DupAddrDetectTransmits + remaining := ndp.configs.DupAddrDetectTransmits { - done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref) + done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) if err != nil { return err } @@ -146,42 +399,59 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, var done bool var timer *time.Timer - timer = time.AfterFunc(n.stack.ndpConfigs.RetransmitTimer, func() { - n.mu.Lock() - defer n.mu.Unlock() + timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { + var d bool + var err *tcpip.Error - if done { - // If we reach this point, it means that the DAD timer - // fired after another goroutine already obtained the - // NIC lock and stopped DAD before it this function - // obtained the NIC lock. Simply return here and do - // nothing further. - return - } + // doDadIteration does a single iteration of the DAD loop. + // + // Returns true if the integrator needs to be informed of DAD + // completing. + doDadIteration := func() bool { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() - ref, ok := n.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we are - // still performing DAD on it. If the endpoint does not - // exist, but we are doing DAD on it, then we started - // DAD at some point, but forgot to stop it when the - // endpoint was deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, n.ID())) - } + if done { + // If we reach this point, it means that the DAD + // timer fired after another goroutine already + // obtained the NIC lock and stopped DAD before + // this function obtained the NIC lock. Simply + // return here and do nothing further. + return false + } - if done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref); err != nil || done { - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, n.ID(), err) + ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] + if !ok { + // This should never happen. + // We should have an endpoint for addr since we + // are still performing DAD on it. If the + // endpoint does not exist, but we are doing DAD + // on it, then we started DAD at some point, but + // forgot to stop it when the endpoint was + // deleted. + panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) } - ndp.stopDuplicateAddressDetection(addr) - return - } + d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil || d { + delete(ndp.dad, addr) - timer.Reset(n.stack.ndpConfigs.RetransmitTimer) - remaining-- + if err != nil { + log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) + } + // Let the integrator know DAD has completed. + return true + } + + remaining-- + timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + return false + } + + if doDadIteration() && ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + } }) ndp.dad[addr] = dadState{ @@ -204,11 +474,11 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, // The NIC that ndp belongs to (n) MUST be locked. 
// // Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { +func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { if ref.getKind() != permanentTentative { // The endpoint should still be marked as tentative // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, n.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) } if remaining == 0 { @@ -219,17 +489,17 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem // Send a new NS. snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := n.endpoints[NetworkEndpointID{snmc}] + snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] if !ok { // This should never happen as if we have the // address, we should have the solicited-node // address. - panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", n.ID(), snmc, addr)) + panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) } // Use the unspecified address as the source address when performing // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, n.linkEP.LinkAddress(), snmcRef, false, false) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) @@ -239,7 +509,9 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: DefaultTOS}); err != nil { + if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return false, err } @@ -275,5 +547,611 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) - return + // Let the integrator know DAD did not resolve. + if ndp.nic.stack.ndpDisp != nil { + go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + } +} + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the NIC configured to handle RAs at all? + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a stack-wide configuration, so we check + // stack's forwarding flag to determine if the NIC is a routing + // interface. + if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + return + } + + // Is the NIC configured to discover default routers? 
+ if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation timer. + timer := rtr.invalidationTimer + + // We should ALWAYS have an invalidation timer for a + // discovered router. + if timer == nil { + panic("ndphandlera: RA invalidation timer should not be nil") + } + + if !timer.Stop() { + // If we reach this point, then we know the + // timer fired after we already took the NIC + // lock. Inform the timer not to invalidate the + // router when it obtains the lock as we just + // got a new RA that refreshes its lifetime to a + // non-zero value. See + // defaultRouterState.doNotInvalidate for more + // details. + *rtr.doNotInvalidate = true + } + + timer.Reset(rl) + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + + // TODO(b/141556115): Do (RetransTimer, ReachableTime) Parameter + // Discovery. + + // We know the options are valid as far as wire format is concerned since + // we got the Router Advertisement, as documented by this function. Given + // this, we do not check the iterator for errors on calls to Next. + it, _ := ra.Options().Iter(false) + for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) + + case header.NDPPrefixInformation: + prefix := opt.Subnet() + + // Is the prefix a link-local? + if header.IsV6LinkLocalAddress(prefix.ID()) { + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). + continue + } + + // Is the Prefix Length 0? + if prefix.Prefix() == 0 { + // ...Yes, skip as this is an invalid prefix: + // a zero-length prefix would make every IPv6 + // address on-link. + continue + } + + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) + } + + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) + } + } + + // TODO(b/141556115): Do (MTU) Parameter Discovery. + } +} + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationTimer.Stop() + rtr.invalidationTimer = nil + *rtr.doNotInvalidate = true + rtr.doNotInvalidate = nil + + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the NIC. +// +// The NIC that ndp belongs to MUST be locked.
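+//
+// The doNotInvalidate flag used below follows the race-avoidance pattern
+// described on defaultRouterState. A minimal, standalone sketch of that
+// pattern (illustrative names, not stack code):
+//
+//	var mu sync.Mutex
+//	var doNotInvalidate bool
+//	timer := time.AfterFunc(lifetime, func() {
+//		mu.Lock()
+//		defer mu.Unlock()
+//		if doNotInvalidate {
+//			doNotInvalidate = false
+//			return
+//		}
+//		// ... invalidate the router ...
+//	})
+//
+//	// Later, while holding mu, when a refreshing RA arrives:
+//	if !timer.Stop() {
+//		// The timer already fired and is blocked on mu; tell it
+//		// not to invalidate.
+//		doNotInvalidate = true
+//	}
+//	timer.Reset(newLifetime)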
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator that we discovered a default router. + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition. See defaultRouterState.doNotInvalidate for more + // details. + var doNotInvalidate bool + + ndp.defaultRouters[ip] = defaultRouterState{ + invalidationTimer: time.AfterFunc(rl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if doNotInvalidate { + doNotInvalidate = false + return + } + + ndp.invalidateDefaultRouter(ip) + }), + doNotInvalidate: &doNotInvalidate, + } +} + +// rememberOnLinkPrefix remembers a newly discovered on-link prefix, prefix, +// with lifetime l. +// +// The prefix identified by prefix MUST NOT already be known. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator that we discovered an on-link prefix. + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + // Informed by the integrator to not remember the prefix, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition. See onLinkPrefixState.doNotInvalidate for more + // details. + var doNotInvalidate bool + var timer *time.Timer + + // Only create a timer if the lifetime is not infinite. + if l < header.NDPInfiniteLifetime { + timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + } + + ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ + invalidationTimer: timer, + doNotInvalidate: &doNotInvalidate, + } +} + +// invalidateOnLinkPrefix invalidates a discovered on-link prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { + s, ok := ndp.onLinkPrefixes[prefix] + + // Is the on-link prefix still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + if s.invalidationTimer != nil { + s.invalidationTimer.Stop() + s.invalidationTimer = nil + *s.doNotInvalidate = true + } + + s.doNotInvalidate = nil + + delete(ndp.onLinkPrefixes, prefix) + + // Let the integrator know a discovered on-link prefix is invalidated. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + } +} + +// prefixInvalidationCallback returns a new on-link prefix invalidation timer +// for prefix that fires after vl. +// +// doNotInvalidate is used to signal the timer when it fires at the same time +// that a prefix's valid lifetime gets refreshed. See +// onLinkPrefixState.doNotInvalidate for more details.
+func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer {
+ return time.AfterFunc(vl, func() {
+ ndp.nic.stack.mu.Lock()
+ defer ndp.nic.stack.mu.Unlock()
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if *doNotInvalidate {
+ *doNotInvalidate = false
+ return
+ }
+
+ ndp.invalidateOnLinkPrefix(prefix)
+ })
+}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ // Update the invalidation timer.
+ timer := prefixState.invalidationTimer
+
+ if timer == nil && vl >= header.NDPInfiniteLifetime {
+ // Had an infinite valid lifetime before and
+ // continues to have an infinite lifetime. Do
+ // nothing further.
+ return
+ }
+
+ if timer != nil && !timer.Stop() {
+ // If we reach this point, then we know the timer already fired
+ // after we took the NIC lock. Inform the timer to not
+ // invalidate the prefix once it obtains the lock as we just
+ // got a new PI that refreshes its lifetime to a non-zero value.
+ // See onLinkPrefixState.doNotInvalidate for more details.
+ *prefixState.doNotInvalidate = true
+ }
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Prefix is now valid forever so we don't need
+ // an invalidation timer.
+ prefixState.invalidationTimer = nil
+ ndp.onLinkPrefixes[prefix] = prefixState
+ return
+ }
+
+ if timer != nil {
+ // We already have a timer so just reset it to
+ // expire after the new valid lifetime.
+ timer.Reset(vl)
+ return
+ }
+
+ // We do not have a timer so just create a new one.
+ prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already have an auto-generated address for prefix.
+ for _, ref := range ndp.nic.endpoints {
+ if ref.protocol != header.IPv6ProtocolNumber {
+ continue
+ }
+
+ if ref.configType != slaac {
+ continue
+ }
+
+ addr := ref.ep.ID().LocalAddress
+ refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()}
+ if refAddrWithPrefix.Subnet() != prefix {
+ continue
+ }
+
+ //
+ // At this point, we know we are refreshing a SLAAC generated
+ // IPv6 address with the prefix, prefix. Do the work as outlined
+ // by RFC 4862 section 5.5.3.e.
+ //
+
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ panic(fmt.Sprintf("must have an autoGenAddresses entry for the SLAAC generated IPv6 address %s", addr))
+ }
+
+ // TODO(b/143713887): Handle deprecating auto-generated address
+ // after the preferred lifetime.
+
+ // As per RFC 4862 section 5.5.3.e, the valid lifetime of the
+ // address generated by SLAAC is as follows:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or
+ // greater than RemainingLifetime, set the valid lifetime of
+ // the address to the advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours,
+ // ignore the advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the address to 2
+ // hours.
+
+ // Handle the infinite valid lifetime separately as we do not
+ // keep a timer in this case.
+ if vl >= header.NDPInfiniteLifetime {
+ if addrState.invalidationTimer != nil {
+ // Valid lifetime was finite before, but now it
+ // is valid forever.
+ if !addrState.invalidationTimer.Stop() {
+ *addrState.doNotInvalidate = true
+ }
+ addrState.invalidationTimer = nil
+ addrState.validUntil = time.Time{}
+ ndp.autoGenAddresses[addr] = addrState
+ }
+
+ return
+ }
+
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the address was originally set to be valid forever,
+ // assume the remaining time to be the maximum possible value.
+ if addrState.invalidationTimer == nil {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(addrState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ ndp.autoGenAddresses[addr] = addrState
+ return
+ } else {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if addrState.invalidationTimer == nil {
+ addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate)
+ } else {
+ if !addrState.invalidationTimer.Stop() {
+ *addrState.doNotInvalidate = true
+ }
+ addrState.invalidationTimer.Reset(effectiveVl)
+ }
+
+ addrState.validUntil = time.Now().Add(effectiveVl)
+ ndp.autoGenAddresses[addr] = addrState
+ return
+ }
+
+ // We do not already have an address within the prefix, prefix. Do the
+ // work as outlined by RFC 4862 section 5.5.3.d if ndp is configured
+ // to auto-generate global addresses by SLAAC.
+
+ // Are we configured to auto-generate new global addresses?
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 section 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return
+ }
+
+ // Generate an address within prefix from the modified EUI-64 of ndp's
+ // NIC's Ethernet MAC address.
+ addrBytes := make([]byte, header.IPv6AddressSize)
+ copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address])
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.Address(addrBytes)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ // If the NIC already has this address, do nothing further.
+ if ndp.nic.hasPermanentAddrLocked(addr) {
+ return
+ }
+
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+ // Informed by the integrator not to add the address.
+ return
+ }
+
+ if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }, FirstPrimaryEndpoint, permanent, slaac); err != nil {
+ panic(err)
+ }
+
+ // Set up the timers to deprecate and invalidate this newly generated
+ // address.
+
+ // TODO(b/143713887): Handle deprecating auto-generated addresses
+ // after the preferred lifetime.
+
+ var doNotInvalidate bool
+ var vTimer *time.Timer
+ if vl < header.NDPInfiniteLifetime {
+ vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate)
+ }
+
+ ndp.autoGenAddresses[addr] = autoGenAddressState{
+ invalidationTimer: vTimer,
+ doNotInvalidate: &doNotInvalidate,
+ validUntil: time.Now().Add(vl),
+ }
+}
+
+// invalidateAutoGenAddress invalidates an auto-generated address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
+ if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) {
+ return
+ }
+
+ ndp.nic.removePermanentAddressLocked(addr)
+}
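The address constructed above joins the prefix's upper 64 bits with a modified EUI-64 interface identifier, per RFC 4291 appendix A: the universal/local bit of the first MAC octet is flipped and 0xff, 0xfe is spliced between the OUI and the NIC-specific half. A standalone sketch of that expansion using only the standard library (the helper name is hypothetical):

package main

import (
	"fmt"
	"net"
)

// modifiedEUI64 expands a 48-bit MAC into the 64-bit interface identifier
// used by SLAAC: flip the universal/local bit of the first octet and insert
// 0xff, 0xfe in the middle.
func modifiedEUI64(mac net.HardwareAddr) [8]byte {
	var iid [8]byte
	iid[0] = mac[0] ^ 0x02
	iid[1], iid[2] = mac[1], mac[2]
	iid[3], iid[4] = 0xff, 0xfe
	iid[5], iid[6], iid[7] = mac[3], mac[4], mac[5]
	return iid
}

func main() {
	mac, _ := net.ParseMAC("02:02:03:04:05:06")
	fmt.Printf("%x\n", modifiedEUI64(mac)) // 000203fffe040506
}

+// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated
+// address's resources from ndp. If the stack has an NDP dispatcher, it will
+// be notified that addr has been invalidated.
+//
+// Returns true if ndp had resources for addr to clean up.
+//
+// The NIC that ndp belongs to MUST be locked.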
+func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
+ state, ok := ndp.autoGenAddresses[addr]
+
+ if !ok {
+ return false
+ }
+
+ if state.invalidationTimer != nil {
+ state.invalidationTimer.Stop()
+ state.invalidationTimer = nil
+ *state.doNotInvalidate = true
+ }
+
+ state.doNotInvalidate = nil
+
+ delete(ndp.autoGenAddresses, addr)
+
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ })
+ }
+
+ return true
+}
+
+// autoGenAddrInvalidationTimer returns a new invalidation timer for an
+// auto-generated address that fires after vl.
+//
+// doNotInvalidate is used to inform the timer when it fires at the same time
+// that an auto-generated address's valid lifetime gets refreshed. See
+// autoGenAddressState.doNotInvalidate for more details.
+func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer {
+ return time.AfterFunc(vl, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if *doNotInvalidate {
+ *doNotInvalidate = false
+ return
+ }
+
+ ndp.invalidateAutoGenAddress(addr)
+ })
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index cc789d70f..666f86c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,9 +15,12 @@ package stack_test
 import (
+ "encoding/binary"
+ "fmt"
 "testing"
 "time"
 
+ "github.com/google/go-cmp/cmp"
 "gvisor.dev/gvisor/pkg/tcpip"
 "gvisor.dev/gvisor/pkg/tcpip/buffer"
 "gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -29,12 +32,46 @@ import (
 )
 
 const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- linkAddr1 = "\x01\x02\x03\x04\x05\x06"
- linkAddr2 = "\x01\x02\x03\x04\x05\x07"
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ linkAddr1 = "\x02\x02\x03\x04\x05\x06"
+ linkAddr2 = "\x02\x02\x03\x04\x05\x07"
+ linkAddr3 = "\x02\x02\x03\x04\x05\x08"
+ defaultTimeout = 100 * time.Millisecond
 )
 
+var (
+ llAddr1 = header.LinkLocalAddr(linkAddr1)
+ llAddr2 = header.LinkLocalAddr(linkAddr2)
+ llAddr3 = header.LinkLocalAddr(linkAddr3)
+)
+
+// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
+// tcpip.Subnet, and an address where the lower half of the address is composed
+// of the EUI-64 of linkAddr if it is a valid unicast Ethernet address.
+func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
+ prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixBytes),
+ PrefixLen: 64,
+ }
+
+ subnet := prefix.Subnet()
+
+ var addr tcpip.AddressWithPrefix
+ if header.IsValidUnicastEthernetAddress(linkAddr) {
+ addrBytes := []byte(subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ }
+
+ return prefix, subnet, addr
+}
+
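prefixSubnetAddr above ties together the three tcpip types the rest of these tests lean on; the relationship can be seen directly with the same API the helper uses (a small usage sketch):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	prefix := tcpip.AddressWithPrefix{
		Address:   tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
		PrefixLen: 64,
	}
	subnet := prefix.Subnet()

	// Any address whose upper 64 bits match the prefix falls within the
	// subnet, regardless of the interface identifier in the lower half.
	addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x02\x03\xff\xfe\x04\x05\x06")
	fmt.Println(subnet.Contains(addr)) // true
}

// TestDADDisabled tests that an address successfully resolves immediately
// when DAD is not enabled (the default for an empty stack.Options).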
func TestDADDisabled(t *testing.T) { @@ -42,7 +79,7 @@ func TestDADDisabled(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -68,6 +105,160 @@ func TestDADDisabled(t *testing.T) { } } +// ndpDADEvent is a set of parameters that was passed to +// ndpDispatcher.OnDuplicateAddressDetectionStatus. +type ndpDADEvent struct { + nicID tcpip.NICID + addr tcpip.Address + resolved bool + err *tcpip.Error +} + +type ndpRouterEvent struct { + nicID tcpip.NICID + addr tcpip.Address + // true if router was discovered, false if invalidated. + discovered bool +} + +type ndpPrefixEvent struct { + nicID tcpip.NICID + prefix tcpip.Subnet + // true if prefix was discovered, false if invalidated. + discovered bool +} + +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + +var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) + +// ndpDispatcher implements NDPDispatcher so tests can know when various NDP +// related events happen for test purposes. +type ndpDispatcher struct { + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent +} + +// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { + if n.dadC != nil { + n.dadC <- ndpDADEvent{ + nicID, + addr, + resolved, + err, + } + } +} + +// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + true, + } + } + + return n.rememberRouter +} + +// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ + nicID, + addr, + false, + } + } +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + true, + } + } + + return n.rememberPrefix +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. 
+func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ + nicID, + prefix, + false, + } + } +} + +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } + } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -88,9 +279,17 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits @@ -107,8 +306,7 @@ func TestDADResolve(t *testing.T) { stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - // Should have sent an NDP NS almost immediately. - time.Sleep(100 * time.Millisecond) + // Should have sent an NDP NS immediately. if got := stat.Value(); got != 1 { t.Fatalf("got NeighborSolicit = %d, want = 1", got) @@ -124,16 +322,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the remaining time - 500ms, to make sure - // the address is still not resolved. Note, we subtract - // 600ms because we already waited for 100ms earlier, - // so our remaining time is 100ms less than the expected - // time. - // (X - 100ms) - 500ms = X - 600ms - // - // TODO(b/140896005): Use events from the netstack to - // be signalled before DAD resolves. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - 600*time.Millisecond) + // Wait for the remaining time - some delta (500ms), to + // make sure the address is still not resolved. + const delta = 500 * time.Millisecond + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) @@ -142,13 +334,30 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) } - // Wait for the remaining time + 250ms, at which point - // the address should be resolved. Note, the remaining - // time is 500ms. See above comments. - // - // TODO(b/140896005): Use events from the netstack to - // know immediately when DAD completes. - time.Sleep(750 * time.Millisecond) + // Wait for DAD to resolve. 
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
 addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
 if err != nil {
 t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
@@ -172,7 +381,8 @@
 }
 
 // Check NDP packet.
- checker.IPv6(t, p.Header.ToVectorisedView().First(),
+ checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.TTL(header.NDPHopLimit),
 checker.NDPNS(
 checker.NDPNSTargetAddress(addr1)))
 }
@@ -250,13 +460,18 @@
 for _, test := range tests {
 t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
 opts := stack.Options{
 NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.DefaultNDPConfigurations(),
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
 }
 opts.NDPConfigs.RetransmitTimer = time.Second * 2
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
 s := stack.New(opts)
 if err := s.CreateNIC(1, e); err != nil {
 t.Fatalf("CreateNIC(_) = %s", err)
@@ -279,15 +494,37 @@
 // Receive a packet to simulate multiple nodes owning or
 // attempting to own the same address.
 hdr := test.makeBuf(addr1)
- e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView())
+ e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
 stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
 if got := stat.Value(); got != 1 {
 t.Fatalf("got stat = %d, want = 1", got)
 }
- // Wait 3 seconds to make sure that DAD did not resolve
- time.Sleep(3 * time.Second)
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
 addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
 if err != nil {
 t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
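The select-against-a-deadline shape used in these DAD checks, and throughout the rest of the file, replaces the fixed sleeps the old tests used: the dispatcher pushes events into a channel and the test races the expected event against a generous timeout. The idiom boiled down to a few self-contained lines (generic names, not the stack's types):

package main

import (
	"errors"
	"fmt"
	"time"
)

// expectEvent races the next event against a deadline, the same shape as
// the DAD checks above: a hung test fails fast instead of sleeping blind.
func expectEvent(events <-chan string, want string, deadline time.Duration) error {
	select {
	case got := <-events:
		if got != want {
			return fmt.Errorf("got event %q, want %q", got, want)
		}
		return nil
	case <-time.After(deadline):
		return errors.New("timed out waiting for event")
	}
}

func main() {
	events := make(chan string)
	go func() {
		time.Sleep(10 * time.Millisecond) // asynchronous work
		events <- "resolved"
	}()
	fmt.Println(expectEvent(events, "resolved", time.Second)) // <nil>
}

@@ -302,13 +539,20 @@
 // TestDADStop tests to make sure that the DAD process stops when an address is
 // removed.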
func TestDADStop(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
 opts := stack.Options{
 NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
 }
- opts.NDPConfigs.RetransmitTimer = time.Second
- opts.NDPConfigs.DupAddrDetectTransmits = 2
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
 s := stack.New(opts)
 if err := s.CreateNIC(1, e); err != nil {
 t.Fatalf("CreateNIC(_) = %s", err)
@@ -332,11 +576,27 @@
 t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
 }
 
- // Wait for the time to normally resolve
- // DupAddrDetectTransmits(2) * RetransmitTimer(1s) = 2s.
- // An extra 250ms is added to make sure that if DAD was still running
- // it resolves and the check below fails.
- time.Sleep(2*time.Second + 250*time.Millisecond)
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
 addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
 if err != nil {
 t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -350,3 +610,1448 @@
 t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
 }
 }
+
+// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NDP configurations using an invalid NICID.
+func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestSetNDPConfigurations tests that we can update and use per-interface NDP
+// configurations without affecting the default NDP configurations or other
+// interfaces' configurations.
+func TestSetNDPConfigurations(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransmitTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {
+ "OK",
+ 1,
+ time.Second,
+ time.Second,
+ },
+ {
+ "Invalid Retransmit Timer",
+ 1,
+ 0,
+ time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ })
+
+ // This NIC(1)'s NDP configurations will be updated to
+ // be different from the default.
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Created before updating NIC(1)'s NDP configurations,
+ // but updating NIC(1)'s NDP configurations should not
+ // affect other existing NICs.
+ if err := s.CreateNIC(2, e); err != nil {
+ t.Fatalf("CreateNIC(2) = %s", err)
+ }
+
+ // Update the NDP configurations on NIC(1) to use DAD.
+ configs := stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ }
+ if err := s.SetNDPConfigurations(1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(1, _) = %s", err)
+ }
+
+ // Created after updating NIC(1)'s NDP configurations,
+ // but the stack's default NDP configurations should not
+ // have been updated.
+ if err := s.CreateNIC(3, e); err != nil {
+ t.Fatalf("CreateNIC(3) = %s", err)
+ }
+
+ // Add addresses for each NIC.
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+ if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err)
+ }
+ if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err)
+ }
+
+ // Address should not be considered bound to NIC(1) yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Should get the addresses on NIC(2) and NIC(3)
+ // immediately since we should not have performed DAD on
+ // them: the stack was configured to not do DAD by
+ // default and we only updated the NDP configurations on
+ // NIC(1).
+ addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err)
+ }
+ if addr.Address != addr2 {
+ t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2)
+ }
+ addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err)
+ }
+ if addr.Address != addr3 {
+ t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3)
+ }
+
+ // Sleep until right before resolution (500ms early) to
+ // make sure the address didn't resolve on NIC(1) yet.
+ const delta = 500 * time.Millisecond
+ time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + } + }) + } +} + +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(0) + ra := header.NDPRouterAdvert(pkt.NDPPayload()) + opts := ra.Options() + opts.Serialize(optSer) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + iph.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} +} + +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) +} + +// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix +// Information option. +// +// Note, raBufWithPI does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { + flags := uint8(0) + if onLink { + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 + } + + // A valid header.NDPPrefixInformation must be 30 bytes. + buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. + buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. + buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. 
+ copy(buf[14:], prefix.Address)
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
+ header.NDPPrefixInformation(buf[:]),
+ })
+}
+
+// TestNoRouterDiscovery tests that router discovery will not be performed if
+// configured not to.
+func TestNoRouterDiscovery(t *testing.T) {
+ // Being configured to discover routers means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // router discovery) - that will be done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkRouterEvent checks that the event e is for addr on the NIC with ID 1,
+// and that the discovered flag is set to discovered.
+func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
+ return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered router when the dispatcher asks it not to.
+func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA for a router we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the router in the first place.
+ select { + case <-ndpDisp.routerC: + t.Fatal("should not have received any router events") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +func TestRouterDiscovery(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } + + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from another router (lladdr3) with non-zero lifetime. + l3Lifetime := time.Duration(6) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + expectRouterEvent(llAddr3, true) + + // Rx an RA from lladdr2 with lesser lifetime. + l2Lifetime := time.Duration(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } + + // Wait for lladdr2's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. 
+ expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout)
+}
+
+// TestRouterDiscoveryMaxRouters tests that only
+// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive RAs from 2 more routers than the max number of discovered
+ // default routers.
+ for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ linkAddr := []byte{2, 2, 3, 4, 5, 0}
+ linkAddr[5] = byte(i)
+ llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+
+ if i <= stack.MaxDiscoveredDefaultRouters {
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ } else {
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
+ default:
+ }
+ }
+ }
+}
+
+// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
+// configured not to.
+func TestNoPrefixDiscovery(t *testing.T) {
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+
+ // Being configured to discover prefixes means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // prefix discovery) - that will be done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
+
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix when configured not to")
+ default:
+ }
+ })
+ }
+}
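checkRouterEvent above and checkPrefixEvent below compare event structs whose fields are all unexported, which go-cmp refuses to touch unless an option such as cmp.AllowUnexported explicitly permits it. The idiom in isolation (a minimal sketch with made-up types):

package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

type event struct {
	nicID int
	addr  string
}

func main() {
	want := event{nicID: 1, addr: "fe80::1"}
	got := event{nicID: 1, addr: "fe80::2"}

	// Without cmp.AllowUnexported(event{}), cmp.Diff would panic on the
	// unexported fields; with it, a non-empty diff means a mismatch.
	fmt.Println(cmp.Diff(want, got, cmp.AllowUnexported(event{})))
}

+// checkPrefixEvent checks that the event e is for prefix on the NIC with ID 1,
+// and that the discovered flag is set to discovered.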
+func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + +// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered on-link prefix when the dispatcher asks it not to. +func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + + prefix, subnet, _ := prefixSubnetAddr(0, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have received any prefix events") + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + } +} + +func TestPrefixDiscovery(t *testing.T) { + t.Parallel() + + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) + + // Receive an RA with prefix3 in a PI. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) + + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) + + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation timer plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) +} + +func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { + // Update the infinite lifetime value to a smaller value so we can test + // that when we receive a PI with such a lifetime value, we do not + // invalidate the prefix. + const testInfiniteLifetimeSeconds = 2 + const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime + defer func() { + header.NDPInfiniteLifetime = saved + }() + + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + subnet := prefix.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() + + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } + + // Receive an RA with prefix in an NDP Prefix Information option (PI) + // with infinite valid lifetime which should not get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + expectPrefixEvent(subnet, true) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with finite lifetime. + // The prefix should get invalidated after 1s. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(testInfiniteLifetime):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive an RA with finite lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ expectPrefixEvent(subnet, true)
+
+ // Receive an RA with prefix with an infinite lifetime.
+ // The prefix should not be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultTimeout):
+ }
+
+ // Receive an RA with a prefix with a lifetime value greater than the
+ // set infinite lifetime value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
+ }
+
+ // Receive an RA with 0 lifetime.
+ // The prefix should get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
+ expectPrefixEvent(subnet, false)
+}
+
+// TestPrefixDiscoveryMaxOnLinkPrefixes tests that only
+// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are
+// remembered.
+func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+
+ // Receive an RA with 2 more PIs than the max number of discovered
+ // on-link prefixes.
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} + prefixAddr[7] = byte(i) + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixAddr[:]), + PrefixLen: 64, + } + prefixes[i] = prefix.Subnet() + buf := [30]byte{} + buf[0] = uint8(prefix.PrefixLen) + buf[1] = 128 + binary.BigEndian.PutUint32(buf[2:], 10) + copy(buf[14:], prefix.Address) + + optSer[i] = header.NDPPrefixInformation(buf[:]) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < stack.MaxDiscoveredOnLinkPrefixes { + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") + } + } else { + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") + default: + } + } + } +} + +// Checks to see if list contains an IPv6 address, item. +func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + for _, i := range list { + if i == protocolAddress { + return true + } + } + + return false +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + default: + } + }) + } +} + +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. 
+func TestAutoGenAddr(t *testing.T) {
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+}
+
+// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
+// auto-generated address only gets updated when required to, as specified in
+// RFC 4862 section 5.5.3.e.
+func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
+ const infiniteVL = 4294967295
+ const newMinVL = 5
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ ovl uint32 // original valid lifetime, in seconds
+ nvl uint32 // new (advertised) valid lifetime, in seconds
+ evl uint32 // expected effective valid lifetime, in seconds
+ }{
+ // Should update the VL to the minimum VL for updating if the
+ // new VL is less than newMinVL but was originally greater than
+ // it.
+ {
+ "LargeVLToVLLessThanMinVLForUpdate",
+ 9999,
+ 1,
+ newMinVL,
+ },
+ {
+ "LargeVLTo0",
+ 9999,
+ 0,
+ newMinVL,
+ },
+ {
+ "InfiniteVLToVLLessThanMinVLForUpdate",
+ infiniteVL,
+ 1,
+ newMinVL,
+ },
+ {
+ "InfiniteVLTo0",
+ infiniteVL,
+ 0,
+ newMinVL,
+ },
+
+ // Should not update VL if original VL was less than newMinVL
+ // and the new VL is also less than newMinVL.
+ {
+ "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
+ newMinVL - 1,
+ newMinVL - 3,
+ newMinVL - 1,
+ },
+
+ // Should take the new VL if the new VL is greater than the
+ // remaining time or is greater than newMinVL.
+ {
+ "MoreThanMinVLToLesserButStillMoreThanMinVLForUpdate",
+ newMinVL + 5,
+ newMinVL + 3,
+ newMinVL + 3,
+ },
+ {
+ "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL - 1,
+ newMinVL - 1,
+ },
+ {
+ "SmallVLToGreaterVLThatIsMoreThanMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL + 1,
+ newMinVL + 1,
+ },
+ }
+
+ const delta = 500 * time.Millisecond
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive a new RA with prefix with new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
+
+ // Make sure we do not get any invalidation
+ // events until at least 500ms (delta) before
+ // test.evl.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly received an auto gen addr event")
+ case <-time.After(time.Duration(test.evl)*time.Second - delta):
+ }
+
+ // Wait for another second (2x delta), but now
+ // we expect the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(2 * delta):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
+// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
+// by the user, their resources are cleaned up and an invalidation event is
+// sent to the integrator.
+func TestAutoGenAddrRemoval(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive a PI to auto-generate an address.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Removing the address should result in an invalidation event
+ // immediately.
+ if err := s.RemoveAddress(1, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+
+ // Wait for the original valid lifetime to make sure the original timer
+ // got stopped/cleaned up.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ }
+}
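Editor's aside: the expectations in TestAutoGenAddrValidLifetimeUpdates above encode RFC 4862 section 5.5.3(e): an advertised valid lifetime may always extend an address, but a shorter one is only honored down to a floor (two hours in the RFC; the tests override netstack's MinPrefixInformationValidLifetimeForUpdate to keep the test fast). A minimal sketch of that decision rule with names of my own; it mirrors the behavior the test table documents, not the stack's actual implementation:

```go
package main

import (
	"fmt"
	"time"
)

// minVLForUpdate stands in for netstack's
// MinPrefixInformationValidLifetimeForUpdate (2 hours per RFC 4862).
const minVLForUpdate = 2 * time.Hour

// updatedValidLifetime returns the valid lifetime an address should have
// after an RA advertises a new one, per RFC 4862 section 5.5.3(e).
func updatedValidLifetime(remaining, advertised time.Duration) time.Duration {
	if advertised > minVLForUpdate || advertised > remaining {
		// Large lifetimes, and any lifetime that extends the
		// address, are honored as-is.
		return advertised
	}
	if remaining <= minVLForUpdate {
		// The address is already close to expiring; ignore
		// attempts to shorten it further (defends against
		// lifetime-shortening RA spoofing).
		return remaining
	}
	// Otherwise clamp the reduction to the floor.
	return minVLForUpdate
}

func main() {
	fmt.Println(updatedValidLifetime(9999*time.Hour, time.Hour)) // clamped: 2h
	fmt.Println(updatedValidLifetime(time.Hour, time.Minute))    // ignored: 1h
	fmt.Println(updatedValidLifetime(time.Hour, 3*time.Hour))    // extended: 3h
}
```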
+// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
+// is already assigned to the NIC, the static address remains.
+func TestAutoGenAddrStaticConflict(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Add the address as a static address before SLAAC tries to add it.
+ if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
+ t.Fatalf("AddProtocolAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr)
+ }
+
+ // Receive a PI where the generated address will be the same as the one
+ // that we already have assigned statically.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
+ default:
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr)
+ }
+
+ // Should not get an invalidation event after the PI's invalidation
+ // time.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr)
+ }
+}
+
+// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
+// to the integrator when an RA is received with the NDP Recursive DNS Server
+// option with at least one valid address.
+func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. 
+ select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43f4ad91e..e8401c673 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -19,7 +19,6 @@ import ( "sync" "sync/atomic" - "gvisor.dev/gvisor/pkg/ilist" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -37,13 +36,20 @@ type NIC struct { mu sync.RWMutex spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint stats NICStats + // ndp is the NDP related state for NIC. + // + // Note, read and write operations on ndp require that the NIC is + // appropriately locked. ndp ndpState } @@ -78,16 +84,26 @@ const ( NeverPrimaryEndpoint ) +// newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { - return &NIC{ + // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For + // example, make sure that the link address it provides is a valid + // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + + nic := &NIC{ stack: stack, id: id, name: name, linkEP: ep, loopback: loopback, - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), + packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -99,9 +115,24 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - dad: make(map[tcpip.Address]dadState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), }, } + nic.ndp.nic = nic + + // Register supported packet endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.packetEPs[netProto] = []PacketEndpoint{} + } + for _, netProto := range stack.networkProtocols { + nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + } + + return nic } // enable enables the NIC. enable will attach the link to its LinkEndpoint and @@ -126,11 +157,50 @@ func (n *NIC) enable() *tcpip.Error { // when we perform Duplicate Address Detection, or Router Advertisement // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 // section 4.2 for more information. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + // + // Also auto-generate an IPv6 link-local address based on the NIC's + // link address if it is configured to do so. 
Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] + if !ok { + return nil } - return nil + n.mu.Lock() + defer n.mu.Unlock() + + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + if !n.stack.autoGenIPv6LinkLocal { + return nil + } + + l2addr := n.linkEP.LinkAddress() + + // Only attempt to generate the link-local address if we have a + // valid MAC address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } + + addr := header.LinkLocalAddr(l2addr) + + _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + }, + }, CanBePrimaryEndpoint) + + return err } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -166,18 +236,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN n.mu.RLock() defer n.mu.RUnlock() - list := n.primary[protocol] - if list == nil { - return nil - } - - for e := list.Front(); e != nil; e = e.Next() { - r := e.(*referencedNetworkEndpoint) - // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set. - switch r.ep.ID().LocalAddress { - case header.IPv4Broadcast, header.IPv4Any: - continue - } + for _, r := range n.primary[protocol] { if r.isValidForOutgoing() && r.tryIncRef() { return r } @@ -186,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative +} + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) } @@ -277,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) + }, peb, temporary, static) n.mu.Unlock() return ref @@ -291,9 +364,31 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent. + // Promote the endpoint to become permanent and respect + // the new peb. if ref.tryIncRef() { ref.setKind(permanent) + + refs := n.primary[ref.protocol] + for i, r := range refs { + if r == ref { + switch peb { + case CanBePrimaryEndpoint: + return ref, nil + case FirstPrimaryEndpoint: + if i == 0 { + return ref, nil + } + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + case NeverPrimaryEndpoint: + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) 
+ return ref, nil + } + } + } + + n.insertPrimaryEndpointLocked(ref, peb) + return ref, nil } // tryIncRef failing means the endpoint is scheduled to be removed once @@ -304,10 +399,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent) + return n.addAddressLocked(protocolAddress, peb, permanent, static) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -337,11 +432,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, } // Set up cache if link address resolution exists for this protocol. @@ -362,22 +458,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar n.endpoints[id] = ref - l, ok := n.primary[protocolAddress.Protocol] - if !ok { - l = &ilist.List{} - n.primary[protocolAddress.Protocol] = l - } - - switch peb { - case CanBePrimaryEndpoint: - l.PushBack(ref) - case FirstPrimaryEndpoint: - l.PushFront(ref) - } + n.insertPrimaryEndpointLocked(ref, peb) // If we are adding a tentative IPv6 address, start DAD. if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(n, protocolAddress.AddressWithPrefix.Address, ref); err != nil { + if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -430,8 +515,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for proto, list := range n.primary { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) + for _, ref := range list { // Don't include tentative, expired or tempory endpoints // to avoid confusion and prevent the caller from using // those. @@ -496,6 +580,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { return append(sns, n.addressRanges...) } +// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required +// by peb. +// +// n MUST be locked. +func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { + switch peb { + case CanBePrimaryEndpoint: + n.primary[r.protocol] = append(n.primary[r.protocol], r) + case FirstPrimaryEndpoint: + n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + } +} + func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() @@ -513,9 +610,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { } delete(n.endpoints, id) - wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() - if wasInList { - n.primary[r.protocol].Remove(r) + refs := n.primary[r.protocol] + for i, ref := range refs { + if ref == r { + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) 
+ break
+ }
 }

 r.ep.Close()
@@ -540,9 +640,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {

 isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)

- // If we are removing a tentative IPv6 unicast address, stop DAD.
- if isIPv6Unicast && kind == permanentTentative {
- n.ndp.stopDuplicateAddressDetection(addr)
+ if isIPv6Unicast {
+ // If we are removing a tentative IPv6 unicast address, stop
+ // DAD.
+ if kind == permanentTentative {
+ n.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ // If we are removing an address generated via SLAAC, clean up
+ // its SLAAC resources and notify the integrator.
+ if r.configType == slaac {
+ n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ }
 }

 r.setKind(permanentExpired)
@@ -585,6 +694,11 @@ func (s *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
 // exists yet. Otherwise it just increments its count. n MUST be locked before
 // joinGroupLocked is called.
 func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // TODO(b/143102137): When implementing MLD, make sure MLD packets are
+ // not sent unless a valid link-local address is available for use on n
+ // as an MLD packet's source address must be a link-local address as
+ // outlined in RFC 3810 section 5.
+
 id := NetworkEndpointID{addr}
 joins := n.mcastJoins[id]
 if joins == 0 {
@@ -635,10 +749,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
 return nil
 }

-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
 r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
 r.RemoteLinkAddress = remotelinkAddr
- ref.ep.HandlePacket(&r, vv)
+ ref.ep.HandlePacket(&r, pkt)
 ref.decRef()
 }

@@ -648,9 +762,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
 // Note that the ownership of the slice backing vv is retained by the caller.
 // This rule applies only to the slice itself, not to the items of the slice;
 // the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
 n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))

 netProto, ok := n.stack.networkProtocols[protocol]
 if !ok {
@@ -658,19 +772,39 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
 return
 }

+ // If no local link layer address is provided, assume it was sent
+ // directly to this NIC.
+ if local == "" {
+ local = n.linkEP.LinkAddress()
+ }
+
+ // Are any packet sockets listening for this network protocol?
+ n.mu.RLock()
+ packetEPs := n.packetEPs[protocol]
+ // Also deliver the packet to any packet sockets listening for every
+ // protocol (EthernetProtocolAll). If the packet itself arrived with
+ // protocol EthernetProtocolAll, the lookup above already returned that
+ // list, so don't append it twice.
+ if protocol != header.EthernetProtocolAll { + packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + } + n.mu.RUnlock() + for _, ep := range packetEPs { + ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + } + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(vv.First()) + src, dst := netProto.ParseAddresses(pkt.Data.First()) if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return } @@ -698,31 +832,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) - vv.RemoveFirst() + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } } return } - n.stack.stats.IP.InvalidAddressesReceived.Increment() + // If a packet socket handled the packet, don't treat it as invalid. + if len(packetEPs) == 0 { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + } } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -734,41 +871,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) + n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(vv.First()) < transProto.MinimumPacketSize() { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. 
if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -779,17 +916,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } @@ -838,6 +975,26 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return n.removePermanentAddressLocked(addr) } +// setNDPConfigs sets the NDP configurations for n. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNDPConfigs(c NDPConfigurations) { + c.validate() + + n.mu.Lock() + n.ndp.configs = c + n.mu.Unlock() +} + +// handleNDPRA handles an NDP Router Advertisement message that arrived on n. +func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.handleRA(ip, ra) +} + type networkEndpointKind int32 const ( @@ -857,7 +1014,7 @@ const ( // removing the permanent address from the NIC. permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. 
If its address is re-added before the endpoint is removed, its type @@ -873,8 +1030,50 @@ const ( temporary ) +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return tcpip.ErrNotSupported + } + n.packetEPs[netProto] = append(eps, ep) + + return nil +} + +func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return + } + + for i, epOther := range eps { + if epOther == ep { + n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + return + } + } +} + +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A slaac configured endpoint is an IPv6 endpoint that was + // added by SLAAC as per RFC 4862 section 5.5.3. + slaac +) + type referencedNetworkEndpoint struct { - ilist.Entry ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -889,6 +1088,10 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change after the endpoint is added to a NIC. + configType networkEndpointConfigType } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 9d6157f22..61fd46d66 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -60,24 +60,64 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // this transport endpoint. It sets pkt.TransportHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw // transport protocol endpoints. 
RawTransportEndpoints receive the entire -// packet - including the link, network, and transport headers - as delivered -// to netstack. +// packet - including the network and transport headers - as delivered to +// netstack. type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) +} + +// PacketEndpoint is the interface that needs to be implemented by packet +// transport protocol endpoints. These endpoints receive link layer headers in +// addition to whatever they contain (usually network and transport layer +// headers and a payload). +type PacketEndpoint interface { + // HandlePacket is called by the stack when new packets arrive that + // match the endpoint. + // + // Implementers should treat packet as immutable and should copy it + // before before modification. + // + // linkHeader may have a length of 0, in which case the PacketEndpoint + // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of pkt. + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -107,7 +147,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -125,13 +167,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. + // + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. 
@@ -182,12 +232,17 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error + // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have + // already been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + + // WritePackets writes packets to the given destination address and + // protocol. pkts must not be zero length. + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -199,8 +254,10 @@ type NetworkEndpoint interface { NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. - HandlePacket(r *Route, vv buffer.VectorisedView) + // this network endpoint. It sets pkt.NetworkHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -225,7 +282,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -242,9 +299,15 @@ type NetworkProtocol interface { // packets to the appropriate network endpoint after it has been handled by // the data link layer. type NetworkDispatcher interface { - // DeliverNetworkPacket finds the appropriate network protocol - // endpoint and hands the packet over for further processing. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // DeliverNetworkPacket finds the appropriate network protocol endpoint + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverNetworkPacket takes ownership of pkt. 
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
 }

 // LinkEndpointCapabilities is the type associated with the capabilities
@@ -266,12 +329,18 @@ const (
 CapabilitySaveRestore
 CapabilityDisconnectOk
 CapabilityLoopback
- CapabilityGSO
+ CapabilityHardwareGSO
+
+ // CapabilitySoftwareGSO indicates the link endpoint supports sending
+ // multiple packets using a single call (LinkEndpoint.WritePackets).
+ CapabilitySoftwareGSO
 )

 // LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
 // ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint.
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the
+// stack.
 type LinkEndpoint interface {
 // MTU is the maximum transmission unit for this endpoint. This is
 // usually dictated by the backing physical network; when such a
@@ -293,13 +362,27 @@ type LinkEndpoint interface {
 // link endpoint.
 LinkAddress() tcpip.LinkAddress

- // WritePacket writes a packet with the given protocol through the given
- // route.
+ // WritePacket writes a packet with the given protocol through the
+ // given route. It sets pkt.LinkHeader if a link layer header exists.
+ // pkt.NetworkHeader and pkt.TransportHeader must have already been
+ // set.
 //
 // To participate in transparent bridging, a LinkEndpoint implementation
 // should call eth.Encode with header.EthernetFields.SrcAddr set to
 // r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error
+
+ // WritePackets writes packets with the given protocol through the
+ // given route. pkts must not be zero length.
+ //
+ // Right now, WritePackets is used only when the software segmentation
+ // offload is enabled. If it is used for anything else, the syscall
+ // filters may need to be changed.
+ WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link. The packet
+ // should already have an ethernet header.
+ WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error

 // Attach attaches the data link layer endpoint to the network-layer
 // dispatcher of the stack.
@@ -324,13 +407,14 @@ type LinkEndpoint interface {
 type InjectableLinkEndpoint interface {
 LinkEndpoint

- // Inject injects an inbound packet.
- Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // InjectInbound injects an inbound packet.
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)

- // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ // InjectOutbound writes a fully formed outbound packet directly to the
+ // link.
 //
 // dest is used by endpoints with multiple raw destinations.
- WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -359,10 +443,10 @@ type LinkAddressResolver interface { type LinkAddressCache interface { // CheckLocalAddress determines if the given local address exists, and if it // does not exist. - CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver @@ -373,17 +457,22 @@ type LinkAddressCache interface { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// UnassociatedEndpointFactory produces endpoints for writing packets not -// associated with a particular transport protocol. Such endpoints can be used -// to write arbitrary packets that include the IP header. -type UnassociatedEndpointFactory interface { - NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) +// RawFactory produces endpoints for writing various types of raw packets. +type RawFactory interface { + // NewUnassociatedEndpoint produces endpoints for writing packets not + // associated with a particular transport protocol. Such endpoints can + // be used to write arbitrary packets that include the network header. + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + + // NewPacketEndpoint produces endpoints for reading and writing packets + // that include network and (when cooked is false) link layer headers. + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } // GSOType is the type of GSO segments. @@ -394,8 +483,14 @@ type GSOType int // Types of gso segments. const ( GSONone GSOType = iota + + // Hardware GSO types: GSOTCPv4 GSOTCPv6 + + // GSOSW is used for software GSO segments which have to be sent by + // endpoint.WritePackets. + GSOSW ) // GSO contains generic segmentation offload properties. @@ -423,3 +518,7 @@ type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. 
GSOMaxSize() uint32 } + +// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. +// This isn't a hard limit, because it is never set into packet headers. +const SoftwareGSOMaxSize = (1 << 16) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e72373964..34307ae07 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -47,8 +46,8 @@ type Route struct { // starts. ref *referencedNetworkEndpoint - // loop controls where WritePacket should send packets. - loop PacketLooping + // Loop controls where WritePacket should send packets. + Loop PacketLooping } // makeRoute initializes a new route. It takes ownership of the provided @@ -69,7 +68,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip LocalLinkAddress: localLinkAddr, RemoteAddress: remoteAddr, ref: ref, - loop: loop, + Loop: loop, } } @@ -154,34 +153,54 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop) + err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } return err } +// WritePackets writes the set of packets through the given route. +func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { + if !r.ref.isValidForOutgoing() { + return 0, tcpip.ErrInvalidEndpointState + } + + n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + } + r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + payloadSize := 0 + for i := 0; i < n; i++ { + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) + payloadSize += pkts[i].DataSize + } + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. 
-func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a199bc1cc..0e88643a4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -50,7 +51,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -344,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -351,10 +359,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver - // unassociatedFactory creates unassociated endpoints. If nil, raw - // endpoints are disabled. It is set during Stack creation and is - // immutable. - unassociatedFactory UnassociatedEndpointFactory + // rawFactory creates raw endpoints. If nil, raw endpoints are + // disabled. It is set during Stack creation and is immutable. + rawFactory RawFactory demux *transportDemuxer @@ -362,9 +369,10 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -394,14 +402,31 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // portSeed is a one-time random value initialized at stack startup + // seed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // // TODO(gvisor.dev/issue/940): S/R this field. - portSeed uint32 + seed uint32 - // ndpConfigs is the NDP configurations used by interfaces. + // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations + + // autoGenIPv6LinkLocal determines whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. 
+ ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
 }

 // Options contains optional Stack configuration.
@@ -425,16 +450,35 @@ type Options struct {
 // stack (false).
 HandleLocal bool

- // UnassociatedFactory produces unassociated endpoints raw endpoints.
- // Raw endpoints are enabled only if this is non-nil.
- UnassociatedFactory UnassociatedEndpointFactory
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID

- // NDPConfigs is the NDP configurations used by interfaces.
+ // NDPConfigs is the default NDP configurations used by interfaces.
 //
 // By default, NDPConfigs will have a zero value for its
 // DupAddrDetectTransmits field, implying that DAD will not be performed
 // before assigning an address to a NIC.
 NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // Note, setting this to true does not mean that a link-local address
+ // will be assigned right away, or at all. If Duplicate Address
+ // Detection is enabled, an address will only be assigned if it
+ // successfully resolves. If it fails, no further attempt will be made
+ // to auto-generate an IPv6 link-local address.
+ //
+ // The generated link-local address will follow RFC 4291 Appendix A
+ // guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // RawFactory produces raw endpoints. Raw endpoints are enabled only if
+ // this is non-nil.
+ RawFactory RawFactory
 }

 // TransportEndpointInfo holds useful information about a transport endpoint
@@ -481,22 +525,30 @@ func New(opts Options) *Stack {
 clock = &tcpip.StdClock{}
 }

+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
 // Make sure opts.NDPConfigs contains valid values only.
 opts.NDPConfigs.validate()

 s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- icmpRateLimiter: NewICMPRateLimiter(),
- portSeed: generateRandUint32(),
- ndpConfigs: opts.NDPConfigs,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ ndpConfigs: opts.NDPConfigs,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
+ ndpDisp: opts.NDPDisp,
 }

 // Add specified network protocols.
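Editor's aside: the default behind Options.UniqueID, the unexported uniqueIDGenerator shown earlier in this stack.go diff, is just an atomic counter on a named uint64, so its zero value is ready to use and IDs start at 1. A self-contained sketch of the same pattern under a hypothetical name, showing how an integrator could supply its own generator instead:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// counterID mirrors the stack's default generator: a named integer type
// whose method atomically increments the underlying value.
type counterID uint64

// UniqueID satisfies the stack's UniqueID interface.
func (c *counterID) UniqueID() uint64 {
	return atomic.AddUint64((*uint64)(c), 1)
}

func main() {
	var ids counterID
	fmt.Println(ids.UniqueID(), ids.UniqueID(), ids.UniqueID()) // 1 2 3
	// Passing &ids as stack.Options{UniqueID: &ids} would make the
	// stack hand out these values from Stack.UniqueID.
}
```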
@@ -514,8 +566,8 @@ func New(opts Options) *Stack { } } - // Add the factory for unassociated endpoints, if present. - s.unassociatedFactory = opts.UnassociatedFactory + // Add the factory for raw endpoints, if present. + s.rawFactory = opts.RawFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -523,6 +575,11 @@ func New(opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value @@ -584,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -650,12 +707,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if s.unassociatedFactory == nil { + if s.rawFactory == nil { return nil, tcpip.ErrNotPermitted } if !associated { - return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue) } t, ok := s.transportProtocols[transport] @@ -666,6 +723,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return t.proto.NewRawEndpoint(s, network, waiterQueue) } +// NewPacketEndpoint creates a new packet endpoint listening for the given +// netProto. +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if s.rawFactory == nil { + return nil, tcpip.ErrNotPermitted + } + + return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) +} + // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { @@ -988,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. -func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { +func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() // If a NIC is specified, we try to find the address there only. 
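Both NewRawEndpoint and the new NewPacketEndpoint above are gated on the same factory, so a stack built without Options.RawFactory refuses either with tcpip.ErrNotPermitted. A hedged sketch of the caller side, assuming the usual tcpip, stack, header, and waiter imports; the newCookedSniffer name is illustrative:

// newCookedSniffer asks the stack for a cooked packet endpoint that sees
// all IPv4 traffic on every NIC. It fails with tcpip.ErrNotPermitted when
// the stack was created without a RawFactory.
func newCookedSniffer(s *stack.Stack) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
	var wq waiter.Queue
	ep, err := s.NewPacketEndpoint(true /* cooked */, header.IPv4ProtocolNumber, &wq)
	if err != nil {
		return nil, nil, err
	}
	return ep, &wq, nil
}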
- if nicid != 0 { - nic := s.nics[nicid] + if nicID != 0 { + nic := s.nics[nicID] if nic == nil { return 0 } @@ -1053,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { } // AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.add(fullAddr, linkAddr) // TODO: provide a way for a transport endpoint to receive a signal // that AddLinkAddress for a particular address has been called. } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicid] + nic := s.nics[nicID] if nic == nil { s.mu.RUnlock() return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // RemoveWaker implements LinkAddressCache.RemoveWaker. -func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { +func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { s.mu.RLock() defer s.mu.RUnlock() - if nic := s.nics[nicid]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + if nic := s.nics[nicID]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.removeWaker(fullAddr, waker) } } @@ -1100,6 +1167,31 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() + + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. 
Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1121,6 +1213,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { @@ -1135,6 +1290,109 @@ func (s *Stack) Resume() { } } +// RegisterPacketEndpoint registers ep with the stack, causing it to receive +// all traffic of the specified netProto on the given NIC. If nicID is 0, it +// receives traffic from every NIC. +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + // If no NIC is specified, capture on all devices. + if nicID == 0 { + // Register with each NIC. + for _, nic := range s.nics { + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + s.unregisterPacketEndpointLocked(0, netProto, ep) + return err + } + } + return nil + } + + // Capture on a specific device. + nic, ok := s.nics[nicID] + if !ok { + return tcpip.ErrUnknownNICID + } + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + return err + } + + return nil +} + +// UnregisterPacketEndpoint unregisters ep for packets of the specified +// netProto from the specified NIC. If nicID is 0, ep is unregistered from all +// NICs. +func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + s.mu.Lock() + defer s.mu.Unlock() + s.unregisterPacketEndpointLocked(nicID, netProto, ep) +} + +func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + // If no NIC is specified, unregister on all devices. + if nicID == 0 { + // Unregister with each NIC. 
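The Close and Wait methods above are deliberately two phases: Close fans out to every registered endpoint, while Wait blocks on registered endpoints, endpoints still parked in the cleanup set, and each NIC's link endpoint. A teardown sketch built only from those two calls:

// shutdown drains a stack using the new lifecycle helpers. As the
// comments above warn, endpoints created concurrently may be missed,
// and link endpoints need an implementation-specific stop of their own.
func shutdown(s *stack.Stack) {
	s.Close() // Ask every registered transport endpoint to close.
	s.Wait()  // Block until endpoint and link worker goroutines halt.
}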
+ for _, nic := range s.nics { + nic.unregisterPacketEndpoint(netProto, ep) + } + return + } + + // Unregister in a single device. + nic, ok := s.nics[nicID] + if !ok { + return + } + nic.unregisterPacketEndpoint(netProto, ep) +} + +// WritePacket writes data directly to the specified NIC. It adds an ethernet +// header based on the arguments. +func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + // Add our own fake ethernet header. + ethFields := header.EthernetFields{ + SrcAddr: nic.linkEP.LinkAddress(), + DstAddr: dst, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(&ethFields) + vv := buffer.View(fakeHeader).ToVectorisedView() + vv.Append(payload) + + if err := nic.linkEP.WriteRawPacket(vv); err != nil { + return err + } + + return nil +} + +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + if err := nic.linkEP.WriteRawPacket(payload); err != nil { + return err + } + + return nil +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1286,12 +1544,47 @@ func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tc return nic.dupTentativeAddrDetected(addr) } -// PortSeed returns a 32 bit value that can be used as a seed value for port -// picking. +// SetNDPConfigurations sets the per-interface NDP configurations on the NIC +// with ID id to c. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.setNDPConfigs(c) + + return nil +} + +// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement +// message that it needs to handle. +func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.handleNDPRA(ip, ra) + + return nil +} + +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. // // NOTE: The seed is generated once during stack initialization only.
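The WritePacket and WriteRawPacket pair added above gives packet endpoints an escape hatch straight to a NIC: the first fabricates an ethernet header from the NIC's own link address, the second emits the payload untouched. A usage sketch, with the NIC ID, destination, and protocol chosen only for illustration:

// sendL3Payload writes an IPv4 payload out of NIC 1; the stack prepends
// a fake ethernet header whose source is the NIC's own link address.
func sendL3Payload(s *stack.Stack, dst tcpip.LinkAddress, payload []byte) *tcpip.Error {
	return s.WritePacket(1, dst, header.IPv4ProtocolNumber, buffer.View(payload).ToVectorisedView())
}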
-func (s *Stack) PortSeed() uint32 { - return s.portSeed +func (s *Stack) Seed() uint32 { + return s.seed } func generateRandUint32() uint32 { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 10fd1065f..8fc034ca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,11 +24,14 @@ import ( "sort" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -55,7 +58,7 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int proto *fakeNetworkProtocol @@ -68,7 +71,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid + return f.nicID } func (f *fakeNetworkEndpoint) PrefixLen() int { @@ -83,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b := pkt.Data.First() + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() + nb := pkt.Data.First() if len(nb) < fakeNetHeaderLen { return } - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -119,32 +122,38 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. 
- b := hdr.Prepend(fakeNetHeaderLen) + b := pkt.Header.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + f.HandlePacket(r, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) } if loop&stack.PacketOut == 0 { return nil } - return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -189,9 +198,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, @@ -251,7 +260,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -261,7 +272,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -271,7 +284,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -280,7 +295,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. 
- ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -290,7 +307,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -310,7 +329,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -365,7 +387,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -660,11 +684,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -697,7 +721,7 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) @@ -709,7 +733,7 @@ func TestEndpointExpiration(t *testing.T) { }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -726,13 +750,13 @@ func TestEndpointExpiration(t *testing.T) { buf[0] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -740,7 +764,7 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. 
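A convention recurs throughout these test updates: inbound injections wrap the payload in a tcpip.PacketBuffer via its Data field, while outbound WritePacket calls also populate Header with route-sized prependable space, as the send helper above does. Factored out as a sketch (the function name is illustrative):

// newOutboundPacket builds the PacketBuffer shape used by the updated
// WritePacket call sites: room for headers plus the payload as Data.
func newOutboundPacket(r *stack.Route, payload []byte) tcpip.PacketBuffer {
	return tcpip.PacketBuffer{
		Header: buffer.NewPrependable(int(r.MaxHeaderLength())),
		Data:   buffer.View(payload).ToVectorisedView(),
	}
}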
//----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -755,20 +779,20 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -783,10 +807,10 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -804,10 +828,10 @@ func TestEndpointExpiration(t *testing.T) { // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -823,10 +847,10 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) testSend(t, r, ep, nil) @@ -834,17 +858,17 @@ func TestEndpointExpiration(t *testing.T) { // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. 
//----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -1637,12 +1661,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1650,7 +1674,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1659,17 +1683,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1686,24 +1710,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1714,7 +1738,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1724,17 +1748,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: 
[]stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1753,7 +1777,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1761,7 +1785,7 @@ } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1787,7 +1811,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1847,7 +1873,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) select { case <-ep2.C: @@ -1864,3 +1892,297 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } + +// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses +// (or lack thereof if disabled (default)). Note, DAD will be disabled in +// these tests. +func TestNICAutoGenAddr(t *testing.T) { + tests := []struct { + name string + autoGen bool + linkAddr tcpip.LinkAddress + shouldGen bool + }{ + { + "Disabled", + false, + linkAddr1, + false, + }, + { + "Enabled", + true, + linkAddr1, + true, + }, + { + "Nil MAC", + true, + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty MAC", + true, + tcpip.LinkAddress(""), + false, + }, + { + "Invalid MAC", + true, + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Multicast MAC", + true, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Unspecified MAC", + true, + tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + + if test.shouldGen { + // Should have auto-generated an address and + // resolved immediately (DAD is disabled).
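The want value in the check that follows comes from header.LinkLocalAddr, which applies the RFC 4291 Appendix A guidelines mentioned in the Options documentation. For reference, a self-contained sketch of that modified EUI-64 derivation:

// linkLocalFromMAC sketches the derivation the test expects from
// header.LinkLocalAddr: the fe80::/64 prefix, the MAC split around a
// fixed ff:fe, and the universal/local bit of the first octet flipped.
func linkLocalFromMAC(mac [6]byte) tcpip.Address {
	var a [16]byte
	a[0], a[1] = 0xfe, 0x80
	a[8] = mac[0] ^ 0x02 // Flip the universal/local bit.
	a[9], a[10] = mac[1], mac[2]
	a[11], a[12] = 0xff, 0xfe
	a[13], a[14], a[15] = mac[3], mac[4], mac[5]
	return tcpip.Address(a[:])
}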
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICAutoGenAddrDoesDAD tests that a successfully auto-generated IPv6 +// link-local address will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + +// TestNewPEBOnPromotionToPermanent tests that a new PrimaryEndpointBehavior +// value (peb) is respected when an address's kind gets "promoted" to permanent +// from permanentExpired. +func TestNewPEBOnPromotionToPermanent(t *testing.T) { + pebs := []stack.PrimaryEndpointBehavior{ + stack.NeverPrimaryEndpoint, + stack.CanBePrimaryEndpoint, + stack.FirstPrimaryEndpoint, + } + + for _, pi := range pebs { + for _, ps := range pebs { + t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + // Add a permanent address with initial + // PrimaryEndpointBehavior (peb), pi.
If pi is + // NeverPrimaryEndpoint, the address should not + // be returned by a call to GetMainNICAddress; + // else, it should. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + addr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if pi == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal("NewSubnet failed:", err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // Take a route through the address so its ref + // count gets incremented and does not actually + // get deleted when RemoveAddress is called + // below. This is because we want to test that a + // new peb is respected when an address gets + // "promoted" to permanent from a + // permanentExpired kind. + r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + defer r.Release() + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + + // + // At this point, the address should still be + // known by the NIC, but have its + // kind = permanentExpired. + // + + // Add some other address with peb set to + // FirstPrimaryEndpoint. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + + } + + // Add back the address we removed earlier and + // make sure the new peb was respected. + // (The address should just be promoted now). + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + var primaryAddrs []tcpip.Address + for _, pa := range s.NICInfo()[1].ProtocolAddresses { + primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) + } + var expectedList []tcpip.Address + switch ps { + case stack.FirstPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x01", + "\x03", + } + case stack.CanBePrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + "\x01", + } + case stack.NeverPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + } + } + if !cmp.Equal(primaryAddrs, expectedList) { + t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) + } + + // Once we remove the other address, if the new + // peb, ps, was NeverPrimaryEndpoint, no address + // should be returned by a call to + // GetMainNICAddress; else, our original address + // should be returned.
+ if err := s.RemoveAddress(1, "\x03"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + addr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if ps == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else { + if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + } + }) + } + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 92267ce4d..67c21be42 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,10 +17,10 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -41,6 +41,31 @@ type transportEndpoints struct { rawEndpoints []RawTransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + eps.mu.Lock() + defer eps.mu.Unlock() + epsByNic, ok := eps.endpoints[id] + if !ok { + return + } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + type endpointsByNic struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint @@ -48,9 +73,19 @@ type endpointsByNic struct { seed uint32 } +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] @@ -64,18 +99,17 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { - mpep.handlePacketAll(r, id, vv) + mpep.handlePacketAll(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } - // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() defer epsByNic.mu.RUnlock() @@ -91,7 +125,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -127,21 +161,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T return len(epsByNic.endpoints) == 0 } -// unregisterEndpoint unregisters the endpoint with the given id such that it -// won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - eps.mu.Lock() - defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] - if !ok { - return - } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { - return - } - delete(eps.endpoints, id) -} - // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -183,14 +202,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but do not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. reuse bool } +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps +} + // reciprocalScale scales a value into range [0, n). // // This is similar to val % n, but faster. @@ -224,22 +256,40 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpointsArr[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy except for - // the final one. + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one.
if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, pkt) break } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. } +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } +} + +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } +} + // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { @@ -257,6 +307,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } return nil } @@ -330,9 +389,9 @@ var loopbackSubnet = func() tcpip.Subnet { }() // deliverPacket attempts to find one or more matching transport endpoints, and -// then, if matches are found, delivers the packet to them. Returns true if it -// found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +// then, if matches are found, delivers the packet to them. Returns true if +// the packet no longer needs to be handled. +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -341,13 +400,38 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RLock() // Determine which transport endpoint or endpoints to deliver this packet to. - // If the packet is a broadcast or multicast, then find all matching - // transport endpoints. + // If the packet is a UDP broadcast or multicast, then find all matching + // transport endpoints. If the packet is a TCP packet with a non-unicast + // source or destination address, then do nothing further and instruct + // the caller to do the same. 
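The sort.Slice call above keeps endpointsArr ordered by UniqueID so that selectEndpoint, which folds a hash of the connection tuple through the reciprocalScale helper documented earlier, lands on the same SO_REUSEPORT member before and after a save/restore. The "similar to val % n, but faster" comment usually refers to the multiply-and-shift trick; a sketch, not necessarily the exact code in the tree:

// reciprocalScale maps a 32-bit hash into [0, n) without a modulo,
// assuming val is roughly uniformly distributed.
func reciprocalScale(val, n uint32) uint32 {
	return uint32((uint64(val) * uint64(n)) >> 32)
}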
var destEps []*endpointsByNic - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, vv, id) - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { - destEps = append(destEps, ep) + switch protocol { + case header.UDPProtocolNumber: + if isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, id) + break + } + + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } + + case header.TCPProtocolNumber: + if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single + // source and a single destination; the addresses must + // be unicast. + eps.mu.RUnlock() + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true + } + + fallthrough + + default: + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } } eps.mu.RUnlock() @@ -361,17 +445,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // Deliver the packet. - for _, ep := range destEps { - ep.handlePacket(r, id, vv) + // HandlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEps[:len(destEps)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEps[len(destEps)-1].handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -385,7 +471,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -395,7 +481,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -403,7 +489,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() // Fail if we didn't find one. @@ -412,12 +498,12 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // Deliver the packet. 
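One ownership rule threads through this hunk and handlePacketAll above: HandlePacket consumes its PacketBuffer, so fan-out paths clone for every receiver except the last, which gets the original. Distilled into a sketch (the fanOut name is illustrative):

// fanOut delivers pkt to each endpoint, cloning for all but the final
// receiver, which takes ownership of the original buffer.
func fanOut(eps []stack.TransportEndpoint, r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
	if len(eps) == 0 {
		return
	}
	for _, ep := range eps[:len(eps)-1] {
		ep.HandlePacket(r, id, pkt.Clone())
	}
	eps[len(eps)-1].HandlePacket(r, id, pkt)
}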
- ep.handleControlPacket(n, id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -445,14 +531,44 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv bu if ep, ok := eps.endpoints[nid]; ok { matchedEPs = append(matchedEPs, ep) } - return matchedEPs } +// findTransportEndpoint finds a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } + } + + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + // findEndpointLocked returns the endpoint that most closely matches the given // id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { return matchedEPs[0] } return nil @@ -465,7 +581,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return nil + return tcpip.ErrNotSupported } eps.mu.Lock() @@ -496,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN func isMulticastOrBroadcast(addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) } + +func isUnicast(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 210233dc0..3b28b06d0 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -79,17 +79,17 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) linkEPs := make(map[string]*channel.Endpoint) for i, linkEpName := range linkEpNames { channelEP := channel.New(256, mtu, "") - nicid := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); + nicID := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicID, linkEpName, channelEP);
err != nil { t.Fatalf("CreateNIC failed: %v", err) } linkEPs[linkEpName] = channelEP - if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -156,7 +156,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..748ce4ea5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint @@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -82,7 +83,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(v).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -144,6 +148,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -192,7 +200,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. 
f.proto.packetCount++ if f.acceptQueue != nil { @@ -209,7 +217,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -218,15 +226,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -251,7 +259,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -266,7 +274,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -337,7 +345,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. 
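The hunks that follow repeat one mechanical migration: the old link endpoint entry point Inject(protocol, vv) becomes InjectInbound(protocol, tcpip.PacketBuffer{...}), with the payload carried in the Data field. A minimal sketch of the new call shape (linkEP, fakeNetNumber, and buf are names from the tests in this diff; not a complete test):

	// Old: linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
	// New: the same bytes, wrapped in a tcpip.PacketBuffer.
	buf := buffer.NewView(30)
	linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
		Data: buf.ToVectorisedView(),
	})
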
buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -346,7 +356,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -355,7 +367,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -408,7 +422,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -417,7 +433,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -426,7 +444,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -579,7 +599,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -598,10 +620,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination address: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source address: got = %d, want = 1", src) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 678a94616..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -231,6 +231,13 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// Equal returns true if s equals o. +// +// Needed to use cmp.Equal on Subnet as its fields are unexported. +func (s Subnet) Equal(o Subnet) bool { + return s == o } + // NICID is a number that uniquely identifies a NIC. 
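The Subnet.Equal method just above exists for the github.com/google/go-cmp package: cmp.Equal refuses to compare types with unexported fields unless the type provides an Equal method. A hedged sketch of the intended use (test body is illustrative, not from this diff):

	// With Subnet.Equal defined, cmp.Equal(s1, s2) calls the method
	// instead of panicking on Subnet's unexported address/mask fields.
	s1, _ := tcpip.NewSubnet("\x0a\x00\x00\x00", "\xff\x00\x00\x00")
	s2, _ := tcpip.NewSubnet("\x0a\x00\x00\x00", "\xff\x00\x00\x00")
	if !cmp.Equal(s1, s2) {
		t.Errorf("subnets did not compare equal")
	}
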
type NICID int32 @@ -255,7 +262,7 @@ type FullAddress struct { // This may not be used by all endpoint types. NIC NICID - // Addr is the network address. + // Addr is the network or link layer address. Addr Address // Port is the transport port. @@ -301,7 +308,7 @@ type ControlMessages struct { // HasTimestamp indicates whether Timestamp is valid/set. HasTimestamp bool - // Timestamp is the time (in ns) that the last packed used to create + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. Timestamp int64 @@ -310,6 +317,18 @@ type ControlMessages struct { // Inq is the number of bytes ready to be received. Inq int32 + + // HasTOS indicates whether TOS is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS int8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass int32 } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -489,6 +508,11 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -501,11 +525,6 @@ type ErrorOption struct{} // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int -// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be -// sent out immediately by the transport protocol. For TCP, it determines if the -// Nagle algorithm is on or off. -type DelayOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -557,6 +576,11 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a +// user-specified timeout for a given TCP connection. +// See RFC 5482 for details. +type TCPUserTimeoutOption time.Duration + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string @@ -579,6 +603,16 @@ type MaxSegOption int // A zero value indicates the default. type TTLOption uint8 +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -673,6 +707,11 @@ func (s *StatCounter) Increment() { s.IncrementBy(1) } +// Decrement subtracts one from the counter. +func (s *StatCounter) Decrement() { + s.IncrementBy(^uint64(0)) +} + // Value returns the current value of the counter. func (s *StatCounter) Value() uint64 { return atomic.LoadUint64(&s.count) } @@ -881,6 +920,23 @@ type TCPStats struct { // successfully via Listen. 
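The Decrement helper above works because unsigned addition wraps modulo 2^64: adding ^uint64(0), i.e. 2^64-1, is the same as subtracting one. A self-contained illustration (standalone program, not part of the diff):

	package main

	import (
		"fmt"
		"sync/atomic"
	)

	func main() {
		var count uint64 = 5
		// ^uint64(0) == 2^64-1; adding it wraps around to count-1.
		atomic.AddUint64(&count, ^uint64(0))
		fmt.Println(count) // Prints 4.
	}
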
PassiveConnectionOpenings *StatCounter + // CurrentEstablished is the number of TCP connections for which the + // current state is either ESTABLISHED or CLOSE-WAIT. + CurrentEstablished *StatCounter + + // EstablishedResets is the number of times TCP connections have made + // a direct transition to the CLOSED state from either the + // ESTABLISHED state or the CLOSE-WAIT state. + EstablishedResets *StatCounter + + // EstablishedClosed is the number of times established TCP connections + // made a transition to CLOSED state. + EstablishedClosed *StatCounter + + // EstablishedTimedout is the number of times an established connection + // was reset because of a keep-alive timeout. + EstablishedTimedout *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter @@ -1308,8 +1364,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index a52262e87..48764b978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9254c3dea..d8c5b5058 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -38,11 +38,3 @@ go_library( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "icmp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3187b336b..9c40931b5 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -58,6 +55,7 @@ type endpoint struct { // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -274,13 +278,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. 
- nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } toCopy := *to @@ -291,7 +295,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Find the endpoint. - r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -425,7 +429,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: data.ToVectorisedView(), + TransportHeader: buffer.View(icmpv4), + }) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -445,13 +453,17 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - icmpv6.SetChecksum(0) - icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + dataVV := data.ToVectorisedView() + icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: dataVV, + TransportHeader: buffer.View(icmpv6), + }) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -479,7 +491,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC localPort := uint16(0) switch e.state { case stateBound, stateConnected: @@ -488,11 +500,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -503,7 +515,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -520,14 +532,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // v6only is set to false and this is an ipv6 endpoint. 
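One substantive fix in the send6 hunk above: the ICMPv6 checksum is now computed with header.ICMPv6Checksum, which folds in the IPv6 pseudo-header (source and destination addresses, upper-layer length, next-header value) as RFC 4443 requires; the old code summed only the ICMPv6 header and payload, so conforming receivers would reject the packets. The call shape, repeated in isolation:

	// ICMPv6Checksum includes the pseudo-header in the sum; the checksum
	// field itself is skipped while summing, so no SetChecksum(0) is
	// needed beforehand.
	dataVV := data.ToVectorisedView()
	icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
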
netProtos := []tcpip.NetworkProtocolNumber{netProto} - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id e.route = r.Clone() - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.state = stateConnected @@ -578,18 +590,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -714,18 +726,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) + h := header.ICMPv4(pkt.Data.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) + h := header.ICMPv6(pkt.Data.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -753,19 +765,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + packet.data = pkt.Data - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() @@ -776,7 +788,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
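The registerWithStack hunk above is cut off mid-switch; for reference, the PickEphemeralPort contract it relies on is: the callback returns (true, nil) to claim a port, (false, nil) to try the next candidate, and a non-nil error to abort the search. A sketch of the conventional completion (the ErrPortInUse case is an assumption about the elided code, not copied from this diff):

	_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
		id.LocalPort = p
		switch err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */); err {
		case nil:
			return true, nil // Port claimed; stop searching.
		case tcpip.ErrPortInUse:
			return false, nil // Collision; try the next candidate.
		default:
			return false, err // Hard failure; abort the search.
		}
	})
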
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't @@ -798,3 +810,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bfb16f7c3..9ce500e80 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD new file mode 100644 index 000000000..44b58ff6b --- /dev/null +++ b/pkg/tcpip/transport/packet/BUILD @@ -0,0 +1,38 @@ +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "packet", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "packet", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "packet_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/iptables", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go new file mode 100644 index 000000000..0010b5e5f --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packet provides the implementation of packet sockets (see +// packet(7)). Packet sockets allow applications to: +// +// * manually write and inspect link, network, and transport headers +// * receive all traffic of a given network protocol, or all protocols +// +// Packet sockets are similar to raw sockets, but provide even more power to +// users, letting them effectively talk directly to the network device. 
+// +// Packet sockets skip the input and output iptables chains. +package packet + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal +// to have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + cooked bool + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +// NewEndpoint returns a new packet endpoint. +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + }, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { + return nil, err + } + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (ep *endpoint) ModerateRecvBuf(copied int) {} + +// IPTables implements tcpip.Endpoint.IPTables. +func (ep *endpoint) IPTables() (iptables.IPTables, error) { + return ep.stack.IPTables(), nil +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. 
+ if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + ep.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // TODO(b/129292371): Implement. + return 0, nil, tcpip.ErrInvalidOptionValue +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be +// disconnected, and this function always returns tcpip.ErrNotSupported. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be +// connected, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used +// with Shutdown, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with +// Listen, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with +// Accept, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + // TODO(gvisor.dev/issue/173): Add Bind support. + + // "By default, all packets of the specified protocol type are passed + // to a packet socket. To get packets only from a specific interface + // use bind(2) specifying an address in a struct sockaddr_ll to bind + // the packet socket to an interface. Fields used for binding are + // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." + // - packet(7). + + return tcpip.ErrNotSupported +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. 
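Read above returns tcpip.ErrWouldBlock instead of blocking; callers are expected to combine it with the waiter queue passed to NewEndpoint. A hedged sketch of the usual consumption pattern (ep and wq are assumed to be a packet endpoint and its waiter.Queue; handle is a placeholder):

	// Block until the endpoint becomes readable, then drain packets.
	we, ch := waiter.NewChannelEntry(nil)
	wq.EventRegister(&we, waiter.EventIn)
	defer wq.EventUnregister(&we)

	for {
		v, _, err := ep.Read(nil)
		if err == tcpip.ErrWouldBlock {
			<-ch // Woken by Notify(waiter.EventIn) in HandlePacket.
			continue
		}
		if err != nil {
			break // e.g. tcpip.ErrClosedForReceive after Close.
		}
		handle(v) // Placeholder for application processing.
	}
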
+ if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be +// used with SetSockOpt, and this function always returns +// tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// HandlePacket implements stack.PacketEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + ep.rcvMu.Lock() + + // Drop the packet if the endpoint is closed. + if ep.rcvClosed { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + // Drop the packet if our receive buffer is currently full. + if ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + var packet packet + // TODO(b/129292371): Return network protocol. + if len(pkt.LinkHeader) > 0 { + // Get info directly from the ethernet header. + hdr := header.Ethernet(pkt.LinkHeader) + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(hdr.SourceAddress()), + } + } else { + // Guess the would-be ethernet header. + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(localAddr), + } + } + + if ep.cooked { + // Cooked packets can simply be queued. + packet.data = pkt.Data + } else { + // Raw packets need their ethernet headers prepended before + // queueing. + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(&ethFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + } + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(&packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + ep.stats.PacketsReceived.Increment() + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} + +// Info returns a copy of the endpoint info. +func (ep *endpoint) Info() tcpip.EndpointInfo { + ep.mu.RLock() + // Make a copy of the endpoint info. + ret := ep.TransportEndpointInfo + ep.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. 
+func (ep *endpoint) Stats() tcpip.EndpointStats { + return &ep.stats +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go new file mode 100644 index 000000000..9b88f17e4 --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -0,0 +1,72 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. 
+ if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index fba598d51..00991ac8e 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -4,14 +4,14 @@ load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) go_template_instance( - name = "packet_list", - out = "packet_list.go", + name = "raw_packet_list", + out = "raw_packet_list.go", package = "raw", - prefix = "packet", + prefix = "rawPacket", template = "//pkg/ilist:generic_list", types = { - "Element": "*packet", - "Linker": "*packet", + "Element": "*rawPacket", + "Linker": "*rawPacket", }, ) @@ -20,8 +20,8 @@ go_library( srcs = [ "endpoint.go", "endpoint_state.go", - "packet_list.go", "protocol.go", + "raw_packet_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], @@ -34,14 +34,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/packet", "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b4c660859..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -17,8 +17,7 @@ // // * manually write and inspect transport layer headers and payloads // * receive all traffic of a given transport protocol (e.g. ICMP or UDP) -// * optionally write and inspect network layer and link layer headers for -// packets +// * optionally write and inspect network layer headers of packets // // Raw sockets don't have any notion of ports, and incoming packets are // demultiplexed solely by protocol number. Thus, a raw UDP endpoint will @@ -38,15 +37,11 @@ import ( ) // +stateify savable -type packet struct { - packetEntry +type rawPacket struct { + rawPacketEntry // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -72,7 +67,7 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + rcvList rawPacketList rcvBufSizeMax int `state:".(int)"` rcvBufSize int rcvClosed bool @@ -90,7 +85,6 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL and AF_PACKET. 
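For readers unfamiliar with the go_template_instance rules in the BUILD changes above: they stamp out //pkg/ilist's intrusive doubly linked list for the named element type, threaded through the embedded rawPacketEntry/packetEntry field. Roughly, the generated surface these endpoints use looks like this (sketch of the generated API, assuming the standard ilist template):

	var q rawPacketList      // The zero value is an empty list.
	q.PushBack(&rawPacket{}) // Queue a packet; no extra allocation.
	for !q.Empty() {
		p := q.Front()
		q.Remove(p) // Unlink via the embedded rawPacketEntry.
	}
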
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } @@ -187,17 +181,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess return buffer.View{}, tcpip.ControlMessages{}, err } - packet := e.rcvList.Front() - e.rcvList.Remove(packet) - e.rcvBufSize -= packet.data.Size() + pkt := e.rcvList.Front() + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() e.rcvMu.Unlock() if addr != nil { - *addr = packet.senderAddr + *addr = pkt.senderAddr } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil + return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil } // Write implements tcpip.Endpoint.Write. @@ -344,13 +338,18 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -561,7 +560,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -602,16 +601,17 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - packet := &packet{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) + networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) + combinedVV := networkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() e.rcvList.PushBack(packet) @@ -643,3 +643,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. 
+func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index a6c7cc43a..33bfb56cd 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -20,15 +20,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// saveData saves packet.data field. -func (p *packet) saveData() buffer.VectorisedView { +// saveData saves rawPacket.data field. +func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } -// loadData loads packet.data field. -func (p *packet) loadData(data buffer.VectorisedView) { +// loadData loads rawPacket.data field. +func (p *rawPacket) loadData(data buffer.VectorisedView) { // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point @@ -86,7 +86,9 @@ func (ep *endpoint) Resume(s *stack.Stack) { } } - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { - panic(err) + if ep.associated { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + panic(err) + } } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index a2512d666..f30aa2a4a 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -17,13 +17,19 @@ package raw import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/packet" "gvisor.dev/gvisor/pkg/waiter" ) -// EndpointFactory implements stack.UnassociatedEndpointFactory. +// EndpointFactory implements stack.RawFactory. type EndpointFactory struct{} -// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } + +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. 
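With both factory methods in place, a stack embedder opts in to raw and packet sockets at construction time. A hypothetical wiring sketch (the stack.Options field name RawFactory matches the interface named above, but the exact option set is an assumption, not part of this diff):

	// Hypothetical: enable raw/packet endpoint creation on a new stack.
	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
		RawFactory:         raw.EndpointFactory{},
	})
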
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index aed70e06f..3b353d56c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -28,6 +28,7 @@ go_library( "forwarder.go", "protocol.go", "rcv.go", + "rcv_state.go", "reno.go", "sack.go", "sack_scoreboard.go", @@ -44,6 +45,7 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", @@ -51,6 +53,7 @@ go_library( "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", @@ -60,17 +63,9 @@ go_library( ], ) -filegroup( - name = "autogen", - srcs = [ - "tcp_segment_list.go", - ], - visibility = ["//:sandbox"], -) - go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 844959fa0..5422ae80c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -241,8 +242,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() + // Now inherit any socket options that should be inherited from the + // listening endpoint. + // In the case of a Forwarder, listenEP will be nil, hence this check. + if l.listenEP != nil { + l.listenEP.propagateInheritableOptions(n) + } + // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { n.Close() return nil, err } @@ -268,8 +276,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -288,7 +296,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -297,7 +305,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.state = StateEstablished + ep.isConnectNotified = true ep.mu.Unlock() // Update the receive window scaling. 
We can't do it before the @@ -349,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } } +// propagateInheritableOptions propagates any options set on the listening +// endpoint to the newly created endpoint. +func (e *endpoint) propagateInheritableOptions(n *endpoint) { + e.mu.Lock() + n.userTimeout = e.userTimeout + e.mu.Unlock() +} + // handleSynSegment is called in its own goroutine once the listening endpoint // receives a SYN segment. It is responsible for completing the handshake and // queueing the new endpoint for acceptance. @@ -359,6 +375,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -366,6 +383,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -399,8 +421,19 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - switch s.flags { - case header.TCPFlagSyn: + if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + return + } + + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + + switch { + case s.flags == header.TCPFlagSyn: opts := parseSynSegmentOptions(s) if incSynRcvdCount() { // Only handle the syn if the following conditions hold @@ -433,19 +466,18 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } - case header.TCPFlagAck: + case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -459,6 +491,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } if !synCookiesInUse() { + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // // Send a reset as this is an ACK for which there is no // half open connections and we are not using cookies // yet. 
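Both LISTEN-state rules added above produce the same style of reset; per RFC 793 the segment is formed as <SEQ=SEG.ACK><CTL=RST>, which in this code is the sendTCP call already used for the SYN-ACK case:

	// RST for a stray ACK (or SYN-ACK) in LISTEN: the sequence number is
	// taken from the peer's ACK field, with no ACK and a zero window.
	e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS,
		header.TCPFlagRst, s.ackNumber, 0 /* ack */, 0 /* rcvWnd */, nil, nil)
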
@@ -519,7 +559,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished + n.isConnectNotified = true // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -530,6 +574,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5ea036bea..cdd69f360 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -78,9 +80,6 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 - // amss is the maximum segment size advertised by us to the peer. - amss uint16 - // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -142,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence Number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what Linux uses to + // generate a sequence number that wraps less than once + // per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000 times in a second. + // To wrap the whole 32 bit space would require + // 2^32/15625000 ~ 274 seconds. + // + // This more or less guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -194,6 +218,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // acceptable if the ack field acknowledges the SYN. 
if s.flagIsSet(header.TCPFlagRst) { if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK + // was acceptable then signal the user "error: connection reset", drop + // the segment, enter CLOSED state, delete TCB, and return." + h.ep.mu.Lock() + h.ep.workerCleanup = true + h.ep.mu.Unlock() + // Although the RFC above calls out ECONNRESET, Linux actually returns + // ECONNREFUSED here so we do as well. return tcpip.ErrConnectionRefused } return nil @@ -228,6 +260,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -275,6 +312,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } + // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a + // sequence number outside of the window causes an ACK with the proper seq + // number and "After sending the acknowledgment, drop the unacceptable + // segment and return." + if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + return nil + } + if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole @@ -319,6 +365,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -445,7 +495,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -607,19 +657,14 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -// sendTCP sends a TCP segment with the provided options via the provided -// network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { optLen := len(opts) - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff - } - + hdr := &pkt.Header + packetSize := pkt.DataSize + off := pkt.DataOffset // Initialize the header. 
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + pkt.TransportHeader = buffer.View(tcp) tcp.Encode(&header.TCPFields{ SrcPort: id.LocalPort, DstPort: id.RemotePort, @@ -631,7 +676,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise }) copy(tcp[header.TCPMinimumSize:], opts) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { @@ -641,14 +686,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVV(data, xsum) + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } +} + +func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + mss := int(gso.MSS) + n := (data.Size() + mss - 1) / mss + + // Allocate one big slice for all the headers. + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + buf := make([]byte, n*hdrSize) + pkts := make([]tcpip.PacketBuffer, n) + for i := range pkts { + pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) + } + + size := data.Size() + off := 0 + for i := 0; i < n; i++ { + packetSize := mss + if packetSize > size { + packetSize = size + } + size -= packetSize + pkts[i].DataOffset = off + pkts[i].DataSize = packetSize + pkts[i].Data = data + buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso) + off += packetSize + seq = seq.Add(seqnum.Size(packetSize)) + } + if ttl == 0 { + ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + if err != nil { + r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) + } + r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent)) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. 
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + } + + pkt := tcpip.PacketBuffer{ + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + DataOffset: 0, + DataSize: data.Size(), + Data: data, + } + buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso) + if ttl == 0 { ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -754,10 +864,26 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. + if e.state == StateEstablished || e.state == StateCloseWait { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } e.state = StateError e.HardError = err - if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { + resetSeqNum = e.snd.sndNxt + } + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) } } @@ -771,6 +897,99 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablishedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + +// transitionToStateCloseLocked ensures that the endpoint is +// cleaned up from the transport demuxer, "before" moving to +// StateClose. This will ensure that no packet will be +// delivered to this endpoint from the demuxer when the endpoint +// is transitioned to StateClose. +func (e *endpoint) transitionToStateCloseLocked() { + if e.state == StateClose { + return + } + e.cleanupLocked() + e.state = StateClose + e.stack.Stats().TCP.EstablishedClosed.Increment() +} + +// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed +// segment to any other endpoint other than the current one. This is called +// only when the endpoint is in StateClose and we want to deliver the segment +// to any other listening endpoint. We reply with RST if we cannot find one. +func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil { + replyWithReset(s) + s.decRef() + return + } + ep.(*endpoint).enqueueSegment(s) +} + +func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + e.mu.Lock() + switch e.state { + // In case of a RST in CLOSE-WAIT linux moves + // the socket to closed state with an error set + // to indicate EPIPE. + // + // Technically this seems to be at odds w/ RFC. + // As per https://tools.ietf.org/html/rfc793#section-2.7 + // page 69 the behavior for a segment arriving + // w/ RST bit set in CLOSE-WAIT is inlined below. + // + // ESTABLISHED + // FIN-WAIT-1 + // FIN-WAIT-2 + // CLOSE-WAIT + + // If the RST bit is set then, any outstanding RECEIVEs and + // SEND should receive "reset" responses. All segment queues + // should be flushed. Users should also receive an unsolicited + // general "connection reset" signal. Enter the CLOSED state, + // delete the TCB, and return. + case StateCloseWait: + e.transitionToStateCloseLocked() + e.HardError = tcpip.ErrAborted + e.mu.Unlock() + return false, nil + default: + e.mu.Unlock() + return false, tcpip.ErrConnectionReset + } + } + return true, nil +} + // handleSegments pulls segments from the queue and processes them. It returns // no error if the protocol loop should continue, an error otherwise. func (e *endpoint) handleSegments() *tcpip.Error { @@ -788,14 +1007,34 @@ func (e *endpoint) handleSegments() *tcpip.Error { } if s.flagIsSet(header.TCPFlagRst) { - if e.rcv.acceptable(s.sequenceNumber, 0) { - // RFC 793, page 37 states that "in all states - // except SYN-SENT, all reset (RST) segments are - // validated by checking their SEQ-fields." So - // we only process it if it's acceptable. - s.decRef() - return tcpip.ErrConnectionReset + if ok, err := e.handleReset(s); !ok { + return err } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. 
+ // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -804,7 +1043,33 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return nil + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -830,14 +1095,27 @@ func (e *endpoint) handleSegments() *tcpip.Error { // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.mu.RLock() + userTimeout := e.userTimeout + e.mu.RUnlock() + e.keepalive.Lock() if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } @@ -854,7 +1132,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { // whether it is enabled for this endpoint. func (e *endpoint) resetKeepaliveTimer(receivedData bool) { e.keepalive.Lock() - defer e.keepalive.Unlock() if receivedData { e.keepalive.unacked = 0 } @@ -862,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { // data to send. 
if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() + e.keepalive.Unlock() return } if e.keepalive.unacked > 0 { @@ -869,6 +1147,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } else { e.keepalive.timer.enable(e.keepalive.idle) } + e.keepalive.Unlock() } // disableKeepaliveTimer stops the keepalive timer. @@ -903,7 +1182,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -932,21 +1210,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -954,7 +1217,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -985,13 +1247,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed because + // the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.transitionToStateCloseLocked() + e.mu.Unlock() + return nil }, }, { w: &e.snd.resendWaker, f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { + e.stack.Stats().TCP.EstablishedTimedout.Increment() return tcpip.ErrTimeout } return nil @@ -1028,17 +1297,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n&notifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n&notifyKeepaliveChanged != 0 { @@ -1054,12 +1324,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != StateError { + if e.state != StateClose && e.state != StateError { + // Only block the worker if the endpoint + // is not in closed state or error state.
close(e.drainDone) <-e.undrain } } + if n&notifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1086,15 +1364,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1110,15 +1389,168 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable the close timer as we are now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { - e.state = StateClose + e.stack.Stats().TCP.CurrentEstablished.Decrement() + e.transitionToStateCloseLocked() } + // Lock released below. epilogue() + // epilogue removes the endpoint from the transport-demuxer and + // unlocks e.mu. Now that no new segments can get enqueued to this + // endpoint, try to re-match the segment to a different endpoint + // as the current endpoint is closed. + for { + s := e.segmentQueue.dequeue() + if s == nil { + break + } + + e.tryDeliverSegmentFromClosedEndpoint(s) + } + + // A new SYN was received during TIME_WAIT and we need to abort + // the TIME_WAIT and redirect the segment to the listener queue. + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint.
+ return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than any seen on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n&notifyClose != 0 { + return nil + } + if n&notifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a1b784b49..fe629aa40 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -121,6 +122,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is, + // say, TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -287,6 +293,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex.
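doTimeWait above is driven by the sleeper/waker primitives from pkg/sleep, the same pattern the protocol main loop uses. A minimal standalone sketch of that pattern follows; the waker ID and timer body are illustrative, not the endpoint's actual wiring:

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/sleep"
)

func main() {
	const timeWaitDone = 1
	s := sleep.Sleeper{}
	var w sleep.Waker
	s.AddWaker(&w, timeWaitDone)
	defer s.Done()

	// Arm a timer that asserts the waker, mirroring timeWaitTimer above.
	t := time.AfterFunc(10*time.Millisecond, w.Assert)
	defer t.Stop()

	// Fetch(true) blocks until some registered waker fires and returns
	// its ID; in doTimeWait the timeWaitDone case ends TIME_WAIT.
	if id, _ := s.Fetch(true); id == timeWaitDone {
		fmt.Println("TIME_WAIT expired")
	}
}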
@@ -319,6 +326,11 @@ type endpoint struct { state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to its correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` @@ -330,6 +342,11 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // Values used to reserve a port or register a transport endpoint + // (whichever happens first). + boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 // endpoints with v6only set to false, this could include multiple @@ -411,7 +428,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -458,6 +475,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // userTimeout if non-zero specifies a user specified timeout for + // a connection w/ pending data to send. A connection that has pending + // unacked data will be forcibly aborted if the timeout is reached + // without any data being acked. + userTimeout time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -502,6 +525,36 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of time a socket + // stays in FIN_WAIT2 state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called Close() on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS } // StopWork halts packet processing. Only to be used in tests.
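The clamp in calculateAdvertisedMSS is easy to sanity-check in isolation. In this sketch the route-derived maximum is stubbed out as a plain parameter; the real code computes it with mssForRoute:

package main

import "fmt"

// advertisedMSS mirrors the clamp above: a non-zero user-set MSS wins only
// when it is below the route's maximum.
func advertisedMSS(userMSS, maxMSS uint16) uint16 {
	if userMSS != 0 && userMSS < maxMSS {
		return userMSS
	}
	return maxMSS
}

func main() {
	fmt.Println(advertisedMSS(0, 1460))    // 1460: no user override
	fmt.Println(advertisedMSS(536, 1460))  // 536: smaller user MSS honored
	fmt.Println(advertisedMSS(9000, 1460)) // 1460: oversized request clamped
}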
@@ -550,6 +603,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption @@ -572,6 +626,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptInt(tcpip.DelayOption, 1) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -659,6 +723,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -671,14 +742,18 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -704,9 +779,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -732,16 +805,19 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -752,7 +828,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. 
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -1133,16 +1211,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.sndBufMu.Unlock() return nil - default: - return nil - } -} - -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 - switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { atomic.StoreUint32(&e.delay, 0) @@ -1154,6 +1222,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { case tcpip.CorkOption: if v == 0 { atomic.StoreUint32(&e.cork, 0) @@ -1184,9 +1262,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -1206,7 +1284,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil @@ -1262,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.TCPUserTimeoutOption: + e.mu.Lock() + e.userTimeout = time.Duration(v) + e.mu.Unlock() + return nil + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -1319,6 +1403,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1345,6 +1451,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: e.sndBufMu.Lock() v := e.sndBufSize @@ -1357,8 +1464,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
@@ -1379,13 +1494,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1496,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.keepalive.Unlock() return nil + case *tcpip.TCPUserTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.mu.Unlock() + return nil + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 @@ -1530,6 +1644,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1602,7 +1722,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1611,11 +1731,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1634,7 +1754,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -1649,7 +1769,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) if err != nil { return err } @@ -1664,7 +1784,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1678,15 +1798,18 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // reusePort is false below because connect cannot reuse a port even if // reusePort was set. 
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: + // Port picking successful. Save the details of + // the selected port. e.ID = id + e.boundBindToDevice = e.bindToDevice return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1702,14 +1825,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.isPortReserved = false } e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1729,6 +1852,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -1749,9 +1873,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1766,6 +1889,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1784,14 +1908,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1799,11 +1920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } @@ -1844,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } + if e.state == StateInitial { + // The listen is called on an unbound socket, the socket is + // automatically bound to a random free port with the local + // address set to INADDR_ANY. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return err + } + } + // Endpoint must be bound before it can transition to listen mode. 
if e.state != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() @@ -1851,7 +1990,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { return err } @@ -1895,12 +2034,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -1908,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.bindLocked(addr) +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. @@ -1932,26 +2070,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = flags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func(bindToDevice tcpip.NICID) { + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 e.ID.LocalAddress = "" e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } - }(e.bindToDevice) + }(e.boundPortFlags, e.boundBindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -2001,8 +2146,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -2025,6 +2170,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. 
if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() @@ -2037,7 +2186,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2327,11 +2476,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { return s } -func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityGSO == 0 { - return - } - +func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} switch e.route.NetProto { case header.IPv4ProtocolNumber: @@ -2349,6 +2494,18 @@ func (e *endpoint) initGSO() { e.gso = gso } +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { @@ -2371,6 +2528,23 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index eae17237e..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. 
+ e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -217,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index db40785d3..bc718064c 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,12 +55,23 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. 
MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP // protocol. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. +type DelayEnabled bool + // SendBufferSizeOption allows the default, min and max send buffer sizes for // TCP endpoints to be queried or configured. type SendBufferSizeOption struct { @@ -84,11 +96,14 @@ const ( type protocol struct { mu sync.Mutex sackEnabled bool + delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, waiterQueue), nil } @@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { @@ -147,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo func replyWithReset(s *segment) { // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. 
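The reset-generation rule quoted above reduces to a small decision on the incoming segment's ACK bit; this illustrative helper (not part of the change) computes the fields the code below places in the RST:

// rstFieldsFor mirrors RFC 793 reset generation for a CLOSED connection:
// a segment carrying an ACK is answered with a bare RST whose sequence
// number is taken from that ACK, while any other segment is answered with
// RST|ACK acknowledging the whole incoming segment.
func rstFieldsFor(hasAck bool, segSeq, segLen, segAck uint32) (seq, ack uint32, withAck bool) {
	if hasAck {
		return segAck, 0, false
	}
	return 0, segSeq + segLen, true
}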
if s.flagIsSet(header.TCPFlagAck) { seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) } - - ack := s.sequenceNumber.Add(s.logicalLen()) - - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. @@ -165,6 +193,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue @@ -202,6 +236,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -216,6 +268,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *DelayEnabled: + p.mu.Lock() + *v = DelayEnabled(p.delayEnabled) + p.mu.Unlock() + return nil + case *SendBufferSizeOption: p.mu.Lock() *v = p.sendBufferSize @@ -246,6 +304,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -258,5 +328,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..0a5534959 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -49,16 +50,20 @@ type receiver struct { pendingRcvdSegments segmentHeap pendingBufUsed seqnum.Size pendingBufSize seqnum.Size + + // Time when the last ack was received. 
+ lastRcvdAckTime time.Time `state:".(unixTime)"` } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), - rcvWnd: rcvWnd, - rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), } } @@ -204,15 +209,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: - r.ep.state = StateClose + r.ep.transitionToStateCloseLocked() } r.ep.mu.Unlock() } @@ -253,25 +263,110 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD)) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed (not owned by the application on our end) + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs.
+		if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+			s.decRef()
+			return true, tcpip.ErrConnectionAborted
+		}
+	}
+
 	// We don't care about receive processing anymore if the receive side
 	// is closed.
-	if r.closed {
-		return
+	//
+	// NOTE: We still want to permit a FIN as it's possible only our
+	// end has closed and the peer is yet to send a FIN. Hence we
+	// compare only the payload.
+	segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+	if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+		return true, nil
+	}
+	return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+	r.ep.mu.RLock()
+	state := r.ep.state
+	closed := r.ep.closed
+	r.ep.mu.RUnlock()
+
+	if state != StateEstablished {
+		drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+		if drop || err != nil {
+			return drop, err
+		}
 	}
 
 	segLen := seqnum.Size(s.data.Size())
 	segSeq := s.sequenceNumber
 
 	// If the sequence number range is outside the acceptable range, just
-	// send an ACK. This is according to RFC 793, page 37.
+	// send an ACK and stop further processing of the segment.
+	// This is according to RFC 793, page 68.
 	if !r.acceptable(segSeq, segLen) {
 		r.ep.snd.sendAck()
-		return
+		return true, nil
 	}
 
+	// Store the time of the last ack.
+	r.lastRcvdAckTime = time.Now()
+
 	// Defer segment processing if it can't be consumed now.
 	if !r.consumeSegment(s, segSeq, segLen) {
 		if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
@@ -288,7 +383,7 @@ func (r *receiver) handleRcvdSegment(s *segment) {
 			// have to retransmit.
 			r.ep.snd.sendAck()
 		}
-		return
+		return false, nil
 	}
 
 	// Since we consumed a segment update the receiver's RTT estimate
@@ -315,4 +410,67 @@ func (r *receiver) handleRcvdSegment(s *segment) {
 		r.pendingBufUsed -= s.logicalLen()
 		s.decRef()
 	}
+	return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+	segSeq := s.sequenceNumber
+	segLen := seqnum.Size(s.data.Size())
+
+	// Just silently drop any RST packets in TIME_WAIT. We do not support
+	// TIME_WAIT assassination; as a result we conform with fix 1 as
+	// described in https://tools.ietf.org/html/rfc1337#section-3.
+	if s.flagIsSet(header.TCPFlagRst) {
+		return false, false
+	}
+
+	// If it's a SYN and the sequence number is higher than any seen before
+	// for this connection then try to redirect it to a listening endpoint
+	// if available.
+	//
+	// RFC 1122:
+	//   "When a connection is [...] on TIME-WAIT state [...]
+	//   [a TCP] MAY accept a new SYN from the remote TCP to
+	//   reopen the connection directly, if it:
+
+	//   (1)  assigns its initial sequence number for the new
+	//   connection to be larger than the largest sequence
+	//   number it used on the previous connection incarnation,
+	//   and
+
+	//   (2)  returns to TIME-WAIT state if the SYN turns out
+	//   to be an old duplicate".
+	if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+
+		return false, true
+	}
+
+	// Drop the segment if it does not contain an ACK.
+	if !s.flagIsSet(header.TCPFlagAck) {
+		return false, false
+	}
+
+	// Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go new file mode 100644 index 000000000..2bf21a2e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } @@ -99,8 +100,14 @@ func (s *segment) clone() *segment { return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8332a0179..8a947dc66 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -28,8 +28,11 @@ import ( ) const ( - // minRTO is the minimum allowed value for the retransmit timeout. 
-	minRTO = 200 * time.Millisecond
+	// MinRTO is the minimum allowed value for the retransmit timeout.
+	MinRTO = 200 * time.Millisecond
+
+	// MaxRTO is the maximum allowed value for the retransmit timeout.
+	MaxRTO = 120 * time.Second
 
 	// InitialCwnd is the initial congestion window.
 	InitialCwnd = 10
@@ -134,6 +137,10 @@ type sender struct {
 	// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
 	rttMeasureTime time.Time `state:".(unixTime)"`
 
+	// firstRetransmittedSegXmitTime is the original transmit time of
+	// the first segment that was retransmitted due to RTO expiration.
+	firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
 	closed     bool
 	writeNext  *segment
 	writeList  segmentList
@@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
 	s.rto = s.rtt.srtt + 4*s.rtt.rttvar
 	s.rtt.Unlock()
 
-	if s.rto < minRTO {
-		s.rto = minRTO
+	if s.rto < MinRTO {
+		s.rto = MinRTO
 	}
 }
 
@@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool {
 	s.ep.stack.Stats().TCP.Timeouts.Increment()
 	s.ep.stats.SendErrors.Timeouts.Increment()
 
-	// Give up if we've waited more than a minute since the last resend.
-	if s.rto >= 60*time.Second {
+	// Give up if the RTO has grown to MaxRTO, or if a user timeout is
+	// set and we have exceeded the user-specified timeout since the
+	// first retransmission.
+	s.ep.mu.RLock()
+	uto := s.ep.userTimeout
+	s.ep.mu.RUnlock()
+
+	if s.firstRetransmittedSegXmitTime.IsZero() {
+		// We store the original xmitTime of the segment that we are
+		// about to retransmit as the first retransmission time. This
+		// is required because, by the time the retransmit timer has
+		// expired, the segment has already been sent and left unacked
+		// for the RTO in effect when it was sent.
+		s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
+	}
+
+	elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+	remaining := MaxRTO
+	if uto != 0 {
+		// Cap to the user specified timeout if one is specified.
+		remaining = uto - elapsed
+	}
+
+	if remaining <= 0 || s.rto >= MaxRTO {
 		return false
 	}
 
@@ -447,6 +476,11 @@
 	// below.
 	s.rto *= 2
 
+	// Cap RTO to remaining time.
+	if s.rto > remaining {
+		s.rto = remaining
+	}
+
 	// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
 	//
 	// Retransmit timeouts:
@@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) {
 		// RFC 6298 Rule 5.3
 		if s.sndUna == s.sndNxt {
 			s.outstanding = 0
+			// Reset firstRetransmittedSegXmitTime to the zero value.
+			s.firstRetransmittedSegXmitTime = time.Time{}
 			s.resendTimer.disable()
 		}
 	}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index 12eff8afc..8b20c3455 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) {
 func (s *sender) afterLoad() {
 	s.resendTimer.init(&s.resendWaker)
 }
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+	return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+	s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6d022a266..e8fe4dab5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) {
 	if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
 		t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
 	}
+
+	// Call Connect again to retrieve the handshake failure status
+	// and stats updates.
+	if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+		t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+	}
+
+	if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+	}
+
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+	}
 }
 
 func TestConnectIncrementActiveConnection(t *testing.T) {
@@ -206,17 +220,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
 	c := context.New(t, defaultMTU)
 	defer c.Cleanup()
 
+	// Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
 	wq := &waiter.Queue{}
 	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
 	if err != nil {
-		t.Fatalf("NewEndpoint failed: %v", err)
+		t.Fatalf("NewEndpoint failed: %s", err)
 	}
 	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
-		t.Fatalf("Bind failed: %v", err)
+		t.Fatalf("Bind failed: %s", err)
 	}
 	if err := ep.Listen(10); err != nil {
-		t.Fatalf("Listen failed: %v", err)
+		t.Fatalf("Listen failed: %s", err)
 	}
 
 	// Send a SYN request.
@@ -256,7 +271,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
 	case <-ch:
 		c.EP, _, err = ep.Accept()
 		if err != nil {
-			t.Fatalf("Accept failed: %v", err)
+			t.Fatalf("Accept failed: %s", err)
 		}
 
 	case <-time.After(1 * time.Second):
 		}
 	}
 
+	// Lower stackwide TIME_WAIT timeout so that the reservations
+	// are released instantly on Close.
+	tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+		t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
+	}
+
 	c.EP.Close()
 	checker.IPv4(t, c.GetPacket(), checker.TCP(
 		checker.SrcPort(context.StackPort),
@@ -285,6 +307,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
 	// Get the ACK to the FIN we just sent.
 	c.GetPacket()
 
+	// Since an active close was done we need to wait for a little more than
+	// the TIME_WAIT timeout for the port reservations to be released and the
+	// socket to move to a CLOSED state.
+	time.Sleep(20 * time.Millisecond)
+
 	// Now resend the same ACK; this ACK should generate a RST as there
 	// should be no endpoint in SYN-RCVD state and we are not using
 	// syn-cookies yet.
 The reason we send the same ACK is we need a valid
@@ -296,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
 		checker.SrcPort(context.StackPort),
 		checker.DstPort(context.TestPort),
 		checker.SeqNum(uint32(c.IRS+1)),
-		checker.AckNum(uint32(iss)+1),
-		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+		checker.AckNum(0),
+		checker.TCPFlags(header.TCPFlagRst)))
 }
 
 func TestTCPResetsReceivedIncrement(t *testing.T) {
@@ -376,6 +403,13 @@ func TestConnectResetAfterClose(t *testing.T) {
 	c := context.New(t, defaultMTU)
 	defer c.Cleanup()
 
+	// Set TCPLinger to 3 seconds so that sockets are marked closed
+	// after 3 seconds in FIN_WAIT2 state.
+	tcpLingerTimeout := 3 * time.Second
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+		t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+	}
+
 	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
 	ep := c.EP
 	c.EP = nil
@@ -396,12 +430,24 @@ func TestConnectResetAfterClose(t *testing.T) {
 		DstPort: c.Port,
 		Flags:   header.TCPFlagAck,
 		SeqNum:  790,
-		AckNum:  c.IRS.Add(1),
+		AckNum:  c.IRS.Add(2),
+		RcvWnd:  30000,
+	})
+
+	// Wait for the ep to give up waiting for a FIN.
+	time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+	// Now send an ACK and it should trigger a RST as the endpoint should
+	// not exist anymore.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(2),
 		RcvWnd:  30000,
 	})
 
-	// Wait for the ep to give up waiting for a FIN, and send a RST.
-	time.Sleep(3 * time.Second)
 	for {
 		b := c.GetPacket()
 		tcpHdr := header.TCP(header.IPv4(b).Payload())
@@ -413,15 +459,128 @@ func TestConnectResetAfterClose(t *testing.T) {
 		checker.IPv4(t, b,
 			checker.TCP(
 				checker.DstPort(context.TestPort),
-				checker.SeqNum(uint32(c.IRS)+1),
-				checker.AckNum(790),
-				checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+				checker.SeqNum(uint32(c.IRS)+2),
+				checker.AckNum(0),
+				checker.TCPFlags(header.TCPFlagRst),
 			),
 		)
 		break
 	}
 }
 
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments would be
+// re-enqueued to any listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+	ep := c.EP
+	c.EP = nil
+
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	// Send a FIN for ESTABLISHED --> CLOSE_WAIT
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagFin | header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Get the ACK for the FIN we sent.
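Further down, this test pauses and resumes the endpoint's protocolMainLoop through anonymous interface assertions (ep.(interface{ StopWork() }).StopWork()). That Go idiom deserves a standalone sketch; worker and pause here are hypothetical stand-ins, not types from the tcp package:

    package main

    import "fmt"

    // worker stands in for the TCP endpoint; the test only needs a value
    // that happens to expose StopWork/ResumeWork.
    type worker struct{ stopped bool }

    func (w *worker) StopWork()   { w.stopped = true }
    func (w *worker) ResumeWork() { w.stopped = false }

    // pause uses the same anonymous-interface assertion as the test: it
    // panics if v does not implement StopWork, which is acceptable in test
    // code where the concrete type is known.
    func pause(v interface{}) {
    	v.(interface{ StopWork() }).StopWork()
    }

    func main() {
    	w := &worker{}
    	pause(w)
    	fmt.Println(w.stopped) // true
    }

The assertion avoids exporting test-only hooks on the public tcpip.Endpoint interface.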
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(791),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	// Close the application endpoint for CLOSE_WAIT --> LAST_ACK
+	ep.Close()
+
+	// Get the FIN
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(791),
+			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+		),
+	)
+
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	// Pause the endpoint's protocolMainLoop.
+	ep.(interface{ StopWork() }).StopWork()
+
+	// Enqueue last ACK followed by an ACK matching the endpoint
+	//
+	// Send Last ACK for LAST_ACK --> CLOSED
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  791,
+		AckNum:  c.IRS.Add(2),
+		RcvWnd:  30000,
+	})
+
+	// Send a packet with the ACK flag set; this would generate a RST when
+	// not using SYN cookies, as in this test.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck | header.TCPFlagFin,
+		SeqNum:  792,
+		AckNum:  c.IRS.Add(2),
+		RcvWnd:  30000,
+	})
+
+	// Unpause the endpoint's protocolMainLoop.
+	ep.(interface{ ResumeWork() }).ResumeWork()
+
+	// Wait for the protocolMainLoop to resume and update state.
+	time.Sleep(10 * time.Millisecond)
+
+	// Expect the endpoint to be closed.
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+		t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+	}
+
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+	}
+
+	// Check that the endpoint was moved to CLOSED and that netstack sent a
+	// reset in response to the ACK packet that we sent after the last ACK.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+2),
+			checker.AckNum(0),
+			checker.TCPFlags(header.TCPFlagRst),
+		),
+	)
+}
+
 func TestSimpleReceive(t *testing.T) {
 	c := context.New(t, defaultMTU)
 	defer c.Cleanup()
@@ -474,6 +633,352 @@ func TestSimpleReceive(t *testing.T) {
 	)
 }
 
+// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
+// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+	const mtu = 5000
+	const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+	tests := []struct {
+		name   string
+		setMSS uint16
+		expMSS uint16
+	}{
+		{
+			"EqualToMaxMSS",
+			maxMSS,
+			maxMSS,
+		},
+		{
+			"LessThanMTU",
+			maxMSS - 1,
+			maxMSS - 1,
+		},
+		{
+			"GreaterThanMTU",
+			maxMSS + 1,
+			maxMSS,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			c := context.New(t, mtu)
+			defer c.Cleanup()
+
+			c.Create(-1)
+
+			// Set the MSS socket option.
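The maxMSS constant in this test and its IPv6 twin is plain arithmetic: the MTU minus the minimum IP and TCP header sizes. A sketch with those sizes written as literals (maxMSS here is an illustrative helper; the tests use the header package constants); the SetSockOpt call below then applies the chosen value:

    package main

    import "fmt"

    // maxMSS computes the largest MSS that fits a given MTU assuming
    // minimum-size headers: 20 bytes for IPv4, 40 for IPv6, 20 for TCP.
    func maxMSS(mtu int, ipv6 bool) int {
    	ipHdr := 20
    	if ipv6 {
    		ipHdr = 40
    	}
    	return mtu - ipHdr - 20
    }

    func main() {
    	fmt.Println(maxMSS(5000, false)) // 4960, the IPv4 test's maxMSS
    	fmt.Println(maxMSS(5000, true))  // 4940, the IPv6 test's maxMSS
    }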
+ opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. 
+			checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.TCPFlags(header.TCPFlagSyn),
+				checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+		})
+	}
+}
+
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.Create(-1)
+
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatal("Bind failed:", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatal("Listen failed:", err)
+	}
+
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  100,
+		AckNum:  200,
+	})
+
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagRst),
+		checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(true)
+
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatal("Bind failed:", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatal("Listen failed:", err)
+	}
+
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  100,
+		AckNum:  200,
+	})
+
+	checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagRst),
+		checker.SeqNum(200)))
+}
+
+// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv4.
+func TestTCPAckBeforeAcceptV4(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.Create(-1)
+
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatal("Bind failed:", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatal("Listen failed:", err)
+	}
+
+	irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+	// Send data before accepting the connection.
+	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+	})
+
+	// Receive ACK for the data we sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(iss+1)),
+		checker.AckNum(uint32(irs+5))))
+}
+
+// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv6.
+func TestTCPAckBeforeAcceptV6(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(true)
+
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatal("Bind failed:", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatal("Listen failed:", err)
+	}
+
+	irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+	// Send data before accepting the connection.
+ c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestSendRstOnListenerRxAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true /* v6Only */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +// TestListenShutdown tests for the listening endpoint not processing +// any receive when it is on read shutdown. +func TestListenShutdown(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1 /* epRcvBuf */) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10 /* backlog */); err != nil { + t.Fatal("Listen failed:", err) + } + + if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatal("Shutdown failed:", err) + } + + // Wait for the endpoint state to be propagated. + time.Sleep(10 * time.Millisecond) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: 100, + AckNum: 200, + }) + + c.CheckNoPacket("Packet received when listening socket was shutdown") +} + func TestTOSV4(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -635,6 +1140,71 @@ func TestConnectBindToDevice(t *testing.T) { } } +func TestRstOnSynSent(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) + + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. 
+ b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + // Ensure that we've reached SynSent state + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Send a packet with a proper ACK and a RST flag to cause the socket + // to Error and close out + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -930,8 +1500,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1623,7 +2192,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -1671,7 +2240,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -1704,7 +2273,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOpt(tcpip.DelayOption(0)) + c.EP.SetSockOptInt(tcpip.DelayOption, 0) // Check that data is received. 
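The hunks above migrate tcpip.DelayOption from SetSockOpt to the integer-valued SetSockOptInt. A minimal sketch of the toggling pattern against a trimmed stand-in endpoint (delayedEP is hypothetical; the real method takes a tcpip option constant and returns *tcpip.Error):

    package main

    import "fmt"

    // delayedEP stands in for a tcpip.Endpoint, trimmed to the one method
    // these hunks exercise.
    type delayedEP struct{ delay int }

    // SetSockOptInt records the option; 1 enables Nagle-style coalescing of
    // small writes, 0 flushes them immediately.
    func (e *delayedEP) SetSockOptInt(opt string, v int) error {
    	if opt == "DelayOption" {
    		e.delay = v
    	}
    	return nil
    }

    func main() {
    	ep := &delayedEP{}
    	_ = ep.SetSockOptInt("DelayOption", 1) // like c.EP.SetSockOptInt(tcpip.DelayOption, 1)
    	_ = ep.SetSockOptInt("DelayOption", 0) // like c.EP.SetSockOptInt(tcpip.DelayOption, 0)
    	fmt.Println(ep.delay) // 0: delayed sends are disabled again
    }

The packet reads below verify that disabling the option flushes the second, previously delayed segment.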
 	second := c.GetPacket()
@@ -1741,7 +2310,7 @@ func TestMSSNotDelayed(t *testing.T) {
 		fn   func(tcpip.Endpoint)
 	}{
 		{"no-op", func(tcpip.Endpoint) {}},
-		{"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }},
+		{"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
 		{"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
 	}
 
@@ -2211,6 +2780,13 @@ loop:
 	if tcp.EndpointState(c.EP.State()) != tcp.StateError {
 		t.Fatalf("got EP state is not StateError")
 	}
+
+	if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+		t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+	}
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+	}
 }
 
 func TestSendOnResetConnection(t *testing.T) {
@@ -2905,6 +3481,13 @@ func TestReadAfterClosedState(t *testing.T) {
 	c := context.New(t, defaultMTU)
 	defer c.Cleanup()
 
+	// Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed
+	// after 1 second in TIME_WAIT state.
+	tcpTimeWaitTimeout := 1 * time.Second
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+		t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+	}
+
 	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
 
 	we, ch := waiter.NewChannelEntry(nil)
@@ -2912,12 +3495,12 @@ func TestReadAfterClosedState(t *testing.T) {
 	defer c.WQ.EventUnregister(&we)
 
 	if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
-		t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+		t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
 	}
 
 	// Shutdown immediately for write, check that we get a FIN.
 	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
-		t.Fatalf("Shutdown failed: %v", err)
+		t.Fatalf("Shutdown failed: %s", err)
 	}
 
 	checker.IPv4(t, c.GetPacket(),
@@ -2955,10 +3538,9 @@ func TestReadAfterClosedState(t *testing.T) {
 		),
 	)
 
-	// Give the stack the chance to transition to closed state. Note that since
-	// both the sender and receiver are now closed, we effectively skip the
-	// TIME-WAIT state.
-	time.Sleep(1 * time.Second)
+	// Give the stack the chance to transition to closed state from
+	// TIME_WAIT.
+	time.Sleep(tcpTimeWaitTimeout * 2)
 
 	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
 		t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -2975,7 +3557,7 @@ func TestReadAfterClosedState(t *testing.T) {
 	peekBuf := make([]byte, 10)
 	n, _, err := c.EP.Peek([][]byte{peekBuf})
 	if err != nil {
-		t.Fatalf("Peek failed: %v", err)
+		t.Fatalf("Peek failed: %s", err)
 	}
 
 	peekBuf = peekBuf[:n]
@@ -2986,7 +3568,7 @@ func TestReadAfterClosedState(t *testing.T) {
 	// Receive data.
 	v, _, err := c.EP.Read(nil)
 	if err != nil {
-		t.Fatalf("Read failed: %v", err)
+		t.Fatalf("Read failed: %s", err)
 	}
 
 	if !bytes.Equal(data, v) {
@@ -2996,11 +3578,11 @@ func TestReadAfterClosedState(t *testing.T) {
 	// Now that we drained the queue, check that functions fail with the
 	// right error code.
 	if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
-		t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+		t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
 	}
 
 	if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
-		t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+		t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
 	}
 }
 
@@ -3773,8 +4355,9 @@ func TestKeepalive(t *testing.T) {
 
 	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
 
+	const keepAliveInterval = 10 * time.Millisecond
 	c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
-	c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
+	c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
 	c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
 	c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
 
@@ -3864,19 +4447,43 @@ func TestKeepalive(t *testing.T) {
 		)
 	}
 
+	// Sleep for a little over the KeepAlive interval to make sure
+	// the timer has time to fire after the last ACK and close the
+	// socket.
+	time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
 	// The connection should be terminated after 5 unacked keepalives.
+	// Send an ACK to trigger a RST from the stack as the endpoint should
+	// be dead.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  seqnum.Value(next),
+		RcvWnd:  30000,
+	})
+
 	checker.IPv4(t, c.GetPacket(),
 		checker.TCP(
 			checker.DstPort(context.TestPort),
 			checker.SeqNum(uint32(next)),
-			checker.AckNum(uint32(790)),
-			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+			checker.AckNum(uint32(0)),
+			checker.TCPFlags(header.TCPFlagRst),
 		),
 	)
 
+	if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+	}
+
 	if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
 		t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
 	}
+
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+	}
 }
 
 func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
@@ -3890,7 +4497,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
 		RcvWnd:  30000,
 	})
 
-	// Receive the SYN-ACK reply.w
+	// Receive the SYN-ACK reply.
 	b := c.GetPacket()
 	tcp := header.TCP(header.IPv4(b).Payload())
 	iss = seqnum.Value(tcp.SequenceNumber())
@@ -3923,6 +4530,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
 	return irs, iss
 }
 
+func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+	// Send a SYN request.
+	irs = seqnum.Value(789)
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: srcPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetV6Packet()
+	tcp := header.TCP(header.IPv6(b).Payload())
+	iss = seqnum.Value(tcp.SequenceNumber())
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(srcPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.AckNum(uint32(irs) + 1),
+	}
+
+	if synCookieInUse {
+		// When cookies are in use window scaling is disabled.
+		tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+			WS:  -1,
+			MSS: c.MSSWithoutOptionsV6(),
+		}))
+	}
+
+	checker.IPv6(t, b, checker.TCP(tcpCheckers...))
+
+	// Send ACK.
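The helpers' sequence arithmetic follows from SYN and SYN-ACK each consuming one sequence number, so the closing ACK sent below carries seq irs+1 and ack iss+1. A tiny sketch (finalHandshakeNumbers is illustrative, not part of the change):

    package main

    import "fmt"

    // finalHandshakeNumbers returns the sequence/ack pair the client's
    // closing ACK must carry: the SYN and the SYN-ACK each consume one
    // sequence number. irs is the client's initial sequence number, iss
    // the server's.
    func finalHandshakeNumbers(irs, iss uint32) (seq, ack uint32) {
    	return irs + 1, iss + 1
    }

    func main() {
    	// With the tests' irs of 789 and a server iss of, say, 1000, the
    	// final ACK would carry seq=790 and ack=1001.
    	seq, ack := finalHandshakeNumbers(789, 1000)
    	fmt.Println(seq, ack)
    }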
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: srcPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+		RcvWnd:  30000,
+	})
+	return irs, iss
+}
+
 // TestListenBacklogFull tests that netstack does not complete handshakes if the
 // listen backlog for the endpoint is full.
 func TestListenBacklogFull(t *testing.T) {
@@ -4025,6 +4676,210 @@ func TestListenBacklogFull(t *testing.T) {
 	}
 }
 
+// TestListenNoAcceptNonUnicastV4 makes sure that TCP segments with a
+// non-unicast IPv4 address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+	multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+	otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+	tests := []struct {
+		name    string
+		srcAddr tcpip.Address
+		dstAddr tcpip.Address
+	}{
+		{
+			"SourceUnspecified",
+			header.IPv4Any,
+			context.StackAddr,
+		},
+		{
+			"SourceBroadcast",
+			header.IPv4Broadcast,
+			context.StackAddr,
+		},
+		{
+			"SourceOurMulticast",
+			multicastAddr,
+			context.StackAddr,
+		},
+		{
+			"SourceOtherMulticast",
+			otherMulticastAddr,
+			context.StackAddr,
+		},
+		{
+			"DestUnspecified",
+			context.TestAddr,
+			header.IPv4Any,
+		},
+		{
+			"DestBroadcast",
+			context.TestAddr,
+			header.IPv4Broadcast,
+		},
+		{
+			"DestOurMulticast",
+			context.TestAddr,
+			multicastAddr,
+		},
+		{
+			"DestOtherMulticast",
+			context.TestAddr,
+			otherMulticastAddr,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			c.Create(-1)
+
+			if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+				t.Fatalf("JoinGroup failed: %s", err)
+			}
+
+			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+				t.Fatalf("Bind failed: %s", err)
+			}
+
+			if err := c.EP.Listen(1); err != nil {
+				t.Fatalf("Listen failed: %s", err)
+			}
+
+			irs := seqnum.Value(789)
+			c.SendPacketWithAddrs(nil, &context.Headers{
+				SrcPort: context.TestPort,
+				DstPort: context.StackPort,
+				Flags:   header.TCPFlagSyn,
+				SeqNum:  irs,
+				RcvWnd:  30000,
+			}, test.srcAddr, test.dstAddr)
+			c.CheckNoPacket("Should not have received a response")
+
+			// Handle normal packet.
+			c.SendPacketWithAddrs(nil, &context.Headers{
+				SrcPort: context.TestPort,
+				DstPort: context.StackPort,
+				Flags:   header.TCPFlagSyn,
+				SeqNum:  irs,
+				RcvWnd:  30000,
+			}, context.TestAddr, context.StackAddr)
+			checker.IPv4(t, c.GetPacket(),
+				checker.TCP(
+					checker.SrcPort(context.StackPort),
+					checker.DstPort(context.TestPort),
+					checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+					checker.AckNum(uint32(irs)+1)))
+		})
+	}
+}
+
+// TestListenNoAcceptNonUnicastV6 makes sure that TCP segments with a
+// non-unicast IPv6 address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) { + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv6Any, + context.StackV6Addr, + }, + { + "SourceAllNodes", + header.IPv6AllNodesMulticastAddress, + context.StackV6Addr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackV6Addr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackV6Addr, + }, + { + "DestUnspecified", + context.TestV6Addr, + header.IPv6Any, + }, + { + "DestAllNodes", + context.TestV6Addr, + header.IPv6AllNodesMulticastAddress, + }, + { + "DestOurMulticast", + context.TestV6Addr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestV6Addr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestV6Addr, context.StackV6Addr) + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + func TestListenSynRcvdQueueFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4056,7 +4911,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { SrcPort: context.TestPort, DstPort: context.StackPort, Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(789), + SeqNum: irs, RcvWnd: 30000, }) @@ -4167,7 +5022,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4207,6 +5063,125 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSynRcvdBadSeqNumber(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN to get a SYN-ACK. 
This should put the ep into SYN-RCVD state + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now send a packet with an out-of-window sequence number + largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: largeSeqnum, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Should receive an ACK with the expected SEQ number + b = c.GetPacket() + tcpCheckers = []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.AckNum(uint32(irs) + 1), + checker.SeqNum(uint32(iss + 1)), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now that the socket replied appropriately with the ACK, + // complete the connection to test that the large SEQ num + // did not change the state from SYN-RCVD. + + // Send ACK to move to ESTABLISHED state. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + newEP, _, err := c.EP.Accept() + + if err != nil && err != tcpip.ErrWouldBlock { + t.Fatalf("Accept failed: %s", err) + } + + if err == tcpip.ErrWouldBlock { + // Try to accept the connections in the backlog. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that the TCP socket is usable and in a connected state. + data := "Don't panic" + _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + pkt := c.GetPacket() + tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + if string(tcpHdr.Payload()) != data { + t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) + } +} + func TestPassiveConnectionAttemptIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4381,6 +5356,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + } // Listening endpoint remains in listen state. 
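TestSynRcvdBadSeqNumber above hinges on the receive-window acceptance check: a segment outside [rcvNxt, rcvNxt+rcvWnd) is answered with a plain ACK carrying the expected numbers rather than being processed. A standalone sketch of that check, ignoring sequence wraparound (inWindow is a hypothetical helper, not the seqnum package API):

    package main

    import "fmt"

    // inWindow reports whether segSeq lies in [rcvNxt, rcvNxt+rcvWnd).
    func inWindow(segSeq, rcvNxt, rcvWnd uint32) bool {
    	return segSeq >= rcvNxt && segSeq < rcvNxt+rcvWnd
    }

    func main() {
    	irs, wnd := uint32(789), uint32(30000)
    	largeSeqnum := irs + wnd + 1 // the test's out-of-window number
    	// false: the endpoint stays in SYN-RCVD and sends a challenge ACK.
    	fmt.Println(inWindow(largeSeqnum, irs+1, wnd))
    }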
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -4668,3 +5646,918 @@ func TestReceiveBufferAutoTuning(t *testing.T) { payloadSize *= 2 } } + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, 0) // Delay is disabled by default. + + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption int + }{ + {delayEnabled: false, wantDelayOption: 0}, + {delayEnabled: true, wantDelayOption: 1}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) + } + gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. 
+ iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. 
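These TIME_WAIT tests all exercise handleTimeWaitSegment from rcv.go; its dispositions condense into a small decision function. A sketch with plain values standing in for the tcp package's flag and seqnum types (timeWaitAction is illustrative, and the timestamp update step is omitted):

    package main

    import "fmt"

    // timeWaitAction condenses handleTimeWaitSegment's dispositions for a
    // segment seen in TIME_WAIT. Wraparound is ignored for brevity.
    func timeWaitAction(rst, syn, ack, fin bool, segSeq, segLen, rcvNxt uint32) string {
    	switch {
    	case rst:
    		return "drop (RFC 1337 fix 1: no TIME_WAIT assassination)"
    	case syn && segSeq > rcvNxt:
    		return "redirect to a listening endpoint (RFC 1122)"
    	case !ack:
    		return "drop"
    	case fin && segSeq+1 == rcvNxt:
    		return "ACK and restart the 2MSL timer"
    	case segSeq != rcvNxt || segLen != 0:
    		return "send ACK"
    	default:
    		return "no response"
    	}
    }

    func main() {
    	fmt.Println(timeWaitAction(true, false, false, false, 790, 0, 791))  // RST: drop
    	fmt.Println(timeWaitAction(false, false, true, true, 790, 0, 791))   // dup FIN: ACK, restart
    	fmt.Println(timeWaitAction(false, false, true, false, 790, 0, 791))  // out of order: send ACK
    }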
+ b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. 
+	select {
+	case <-ch:
+		c.EP, _, err = ep.Accept()
+		if err != nil {
+			t.Fatalf("Accept failed: %s", err)
+		}
+
+	case <-time.After(1 * time.Second):
+		t.Fatalf("Timed out waiting for accept")
+	}
+	}
+
+	c.EP.Close()
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)),
+		checker.AckNum(uint32(iss)+1),
+		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+	finHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck | header.TCPFlagFin,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 2,
+	}
+
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+2)),
+		checker.AckNum(uint32(iss)+2),
+		checker.TCPFlags(header.TCPFlagAck)))
+
+	// Send a SYN request with a sequence number lower than
+	// the highest sequence number sent so far. We just reuse
+	// the same number.
+	iss = seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+	})
+
+	c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+
+	// Send a SYN request with a sequence number higher than
+	// the highest sequence number sent so far.
+	iss = seqnum.Value(792)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b = c.GetPacket()
+	tcpHdr = header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	ackHeaders = &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Try to accept the connection.
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+}
+
+func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+	// after 5 seconds in TIME_WAIT state.
+	tcpTimeWaitTimeout := 5 * time.Second
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+		t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeoutOption(%d)) failed: %s", tcpTimeWaitTimeout, err)
+	}
+
+	want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
+
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	// Send a SYN request.
+	iss := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetPacket()
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	ackHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	c.EP.Close()
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)),
+		checker.AckNum(uint32(iss)+1),
+		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+	finHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck | header.TCPFlagFin,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 2,
+	}
+
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+2)),
+		checker.AckNum(uint32(iss)+2),
+		checker.TCPFlags(header.TCPFlagAck)))
+
+	time.Sleep(2 * time.Second)
+
+	// Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+	// by another 5 seconds and also send us a duplicate ACK as it should
+	// indicate that the final ACK was potentially lost.
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+2)),
+		checker.AckNum(uint32(iss)+2),
+		checker.TCPFlags(header.TCPFlagAck)))
+
+	// Sleep for 4 seconds so at this point we are 1 second past the
+	// original tcpTimeWaitTimeout of 5 seconds.
+	time.Sleep(4 * time.Second)
+
+	// Send an ACK and it should not generate any packet as the socket
+	// should still be in TIME_WAIT for another 5 seconds due
+	// to the duplicate FIN we sent earlier.
+	*ackHeaders = *finHeaders
+	ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+	ackHeaders.Flags = header.TCPFlagAck
+	c.SendPacket(nil, ackHeaders)
+
+	c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+
+	// Now sleep for another 2 seconds so that we are past the
+	// extended TIME_WAIT of 7 seconds (2 + 5).
+	time.Sleep(2 * time.Second)
+
+	// Resend the same ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Receive the RST that should be generated as there is no valid
+	// endpoint.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(ackHeaders.AckNum)),
+		checker.AckNum(0),
+		checker.TCPFlags(header.TCPFlagRst)))
+
+	if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
+		t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+	}
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+	}
+}
+
+func TestTCPCloseWithData(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+	// after 5 seconds in TIME_WAIT state.
+	tcpTimeWaitTimeout := 5 * time.Second
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+		t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeoutOption(%d)) failed: %s", tcpTimeWaitTimeout, err)
+	}
+
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	// Send a SYN request.
+	iss := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetPacket()
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	ackHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+		RcvWnd:  30000,
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	// Now trigger a passive close by sending a FIN.
+	finHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck | header.TCPFlagFin,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 2,
+		RcvWnd:  30000,
+	}
+
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)),
+		checker.AckNum(uint32(iss)+2),
+		checker.TCPFlags(header.TCPFlagAck)))
+
+	// Now write a few bytes and then close the endpoint.
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Check that data is received.
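The check that follows slices the captured packet at header.IPv4MinimumSize+header.TCPMinimumSize, which is correct here only because these test packets carry no IP or TCP options. A more general extraction would consult the actual header lengths; tcpPayload below is a hypothetical helper for illustration, not part of the test context:

	func tcpPayload(b []byte) []byte {
		ip := header.IPv4(b)               // IP header length comes from the IHL field.
		tcpHdr := header.TCP(ip.Payload()) // TCP segment starts after the IP header.
		return tcpHdr.Payload()            // Skips the TCP header, options included.
	}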
+	b = c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(iss)+2), // The peer's SYN and FIN each consumed one sequence number.
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+		t.Errorf("got data = %x, want = %x", p, data)
+	}
+
+	c.EP.Close()
+	// Check the FIN.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
+		checker.AckNum(uint32(iss+2)),
+		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+	// First send a partial ACK.
+	ackHeaders = &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 2,
+		AckNum:  c.IRS + 1 + seqnum.Value(len(data)-1),
+		RcvWnd:  30000,
+	}
+	c.SendPacket(nil, ackHeaders)
+
+	// Now send a full ACK.
+	ackHeaders = &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 2,
+		AckNum:  c.IRS + 1 + seqnum.Value(len(data)),
+		RcvWnd:  30000,
+	}
+	c.SendPacket(nil, ackHeaders)
+
+	// Now ACK the FIN.
+	ackHeaders.AckNum++
+	c.SendPacket(nil, ackHeaders)
+
+	// Now send an ACK and we should get a RST back as the endpoint should
+	// be in CLOSED state.
+	ackHeaders = &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 2,
+		AckNum:  c.IRS + 1 + seqnum.Value(len(data)),
+		RcvWnd:  30000,
+	}
+	c.SendPacket(nil, ackHeaders)
+
+	// Check the RST.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(ackHeaders.AckNum)),
+		checker.AckNum(0),
+		checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+	userTimeout := 50 * time.Millisecond
+	c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+	// Send some data and wait before ACKing it.
+	view := buffer.NewView(3)
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %v", err)
+	}
+
+	next := uint32(c.IRS) + 1
+	checker.IPv4(t, c.GetPacket(),
+		checker.PayloadLen(len(view)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(next),
+			checker.AckNum(790),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	// Wait for a little over the minimum retransmit timeout of 200ms for
+	// the retransmitTimer to fire and close the connection.
+	time.Sleep(tcp.MinRTO + 10*time.Millisecond)
+
+	// No packet should be received as the connection should be silently
+	// closed due to timeout.
+	c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+	next += uint32(len(view))
+
+	// The connection should be terminated after userTimeout has expired.
+	// Send an ACK to trigger a RST from the stack as the endpoint should
+	// be dead.
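Before that final exchange, a note on the option under test: tcpip.TCPUserTimeoutOption mirrors Linux's TCP_USER_TIMEOUT, bounding how long transmitted data may remain unacknowledged before the connection is aborted. A minimal usage sketch, where ep stands for any connected TCP endpoint; unlike the test above, real callers should check the SetSockOpt error:

	userTimeout := 50 * time.Millisecond
	if err := ep.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)); err != nil {
		// Handle the error; the test ignores it for brevity.
	}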
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  seqnum.Value(next),
+		RcvWnd:  30000,
+	})
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(next)),
+			checker.AckNum(uint32(0)),
+			checker.TCPFlags(header.TCPFlagRst),
+		),
+	)
+
+	if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+		t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+	}
+
+	if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+	}
+}
+
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+	const keepAliveInterval = 10 * time.Millisecond
+	c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+	c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+	c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
+	c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+
+	// Set userTimeout to be the duration for 3 keepalive probes.
+	userTimeout := 30 * time.Millisecond
+	c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+	// Check that the connection is still alive.
+	if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+	}
+
+	// Now receive 2 keepalives, but don't ACK them. The connection should
+	// be reset when the 3rd one is due: userTimeout is 30ms, and each probe
+	// is sent 10ms apart (as set above) once the connection has been idle
+	// for 10ms.
+	for i := 0; i < 2; i++ {
+		b := c.GetPacket()
+		checker.IPv4(t, b,
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)),
+				checker.AckNum(uint32(790)),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+	}
+
+	// Sleep for a little over the keepalive interval to make sure the
+	// timer has time to fire after the last ACK and close the socket.
+	time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
+	// The connection should be terminated after 30ms.
+	// Send an ACK to trigger a RST from the stack as the endpoint should
+	// be dead.
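Before the final exchange, it is worth spelling out the timing this test relies on: with both the keepalive idle time and interval at 10ms and the user timeout at 30ms, probes fire at roughly 10ms and 20ms of silence, and the third probe would be due at 30ms, exactly when the user timeout expires, so the stack resets the connection instead of probing again. As arithmetic:

	idle := 10 * time.Millisecond
	interval := 10 * time.Millisecond
	userTimeout := 30 * time.Millisecond
	thirdProbeDue := idle + 2*interval // 30ms of silence.
	// thirdProbeDue >= userTimeout, so the user timeout wins and the
	// connection is reset instead of a third probe being sent.
	_, _ = thirdProbeDue, userTimeout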
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index 19b0d31c5..b33ec2087 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -8,7 +8,7 @@ go_library( srcs = ["context.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ - "//:sandbox", + "//visibility:public", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ef823e4ae..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -231,14 +231,15 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -259,14 +260,15 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -302,11 +304,19 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) +} + +// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, +// payload and source and destination IPv4 addresses. 
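A note on the rewritten GetPacket and GetPacketNonBlocking above, before BuildSegmentWithAddrs continues below: the full slice expression hdr[:len(hdr):len(hdr)] caps the slice's capacity at its length, which forces append to allocate a fresh backing array rather than writing into spare capacity the header view may still share with the stack's packet buffer. A self-contained illustration of the idiom:

	hdr := make([]byte, 3, 8)     // Length 3, spare capacity 5.
	risky := append(hdr, 9)       // May scribble on hdr's backing array.
	safe := append(hdr[:3:3], 9)  // Capacity capped: always reallocates.
	_, _ = risky, safe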
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
 	// Allocate a buffer for data and headers.
 	buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
 	copy(buf[len(buf)-len(payload):], payload)
@@ -319,8 +329,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
 		TotalLength: uint16(len(buf)),
 		TTL:         65,
 		Protocol:    uint8(tcp.ProtocolNumber),
-		SrcAddr:     TestAddr,
-		DstAddr:     StackAddr,
+		SrcAddr:     src,
+		DstAddr:     dst,
 	})
 	ip.SetChecksum(^ip.CalculateChecksum())
@@ -337,7 +347,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
 	})
 
 	// Calculate the TCP pseudo-header checksum.
-	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
+	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
 
 	// Calculate the TCP checksum and set it.
 	xsum = header.Checksum(payload, xsum)
@@ -350,13 +360,26 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
 // SendSegment sends a TCP segment that has already been built and written to a
 // buffer.VectorisedView.
 func (c *Context) SendSegment(s buffer.VectorisedView) {
-	c.linkEP.Inject(ipv4.ProtocolNumber, s)
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+		Data: s,
+	})
 }
 
 // SendPacket builds and sends a TCP segment(with the provided payload & TCP
 // headers) in an IPv4 packet via the link layer endpoint.
 func (c *Context) SendPacket(payload []byte, h *Headers) {
-	c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+		Data: c.BuildSegment(payload, h),
+	})
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment (with the provided
+// payload & TCP headers) in an IPv4 packet via the link layer endpoint using
+// the provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+		Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+	})
 }
 
 // SendAck sends an ACK packet.
@@ -462,14 +485,15 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
 // GetV6Packet reads a single packet from the link layer endpoint of the context
 // and asserts that it is an IPv6 Packet with the expected src/dest addresses.
 func (c *Context) GetV6Packet() []byte {
+	c.t.Helper()
 	select {
 	case p := <-c.linkEP.C:
 		if p.Proto != ipv6.ProtocolNumber {
 			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
 		}
-		b := make([]byte, len(p.Header)+len(p.Payload))
-		copy(b, p.Header)
-		copy(b[len(p.Header):], p.Payload)
+		b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+		copy(b, p.Pkt.Header.View())
+		copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
 
 		checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
 		return b
@@ -484,6 +508,13 @@ func (c *Context) GetV6Packet() []byte {
 // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
 // the context.
 func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+	c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
 	// Allocate a buffer for data and headers.
 	buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
 	copy(buf[len(buf)-len(payload):], payload)
@@ -494,8 +525,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
 		PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
 		NextHeader:    uint8(tcp.ProtocolNumber),
 		HopLimit:      65,
-		SrcAddr:       TestV6Addr,
-		DstAddr:       StackV6Addr,
+		SrcAddr:       src,
+		DstAddr:       dst,
 	})
 
 	// Initialize the TCP header.
@@ -511,14 +542,16 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
 	})
 
 	// Calculate the TCP pseudo-header checksum.
-	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
+	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
 
 	// Calculate the TCP checksum and set it.
 	xsum = header.Checksum(payload, xsum)
 	t.SetChecksum(^t.CalculateChecksum(xsum))
 
 	// Inject packet.
-	c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
 }
 
 // CreateConnected creates a connected TCP endpoint.
@@ -1059,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) {
 func (c *Context) MSSWithoutOptions() uint16 {
 	return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
 }
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+	return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c9460aa0d..97e4d5825 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -34,6 +34,7 @@ go_library(
         "//pkg/tcpip/buffer",
         "//pkg/tcpip/header",
         "//pkg/tcpip/iptables",
+        "//pkg/tcpip/ports",
         "//pkg/tcpip/stack",
         "//pkg/tcpip/transport/raw",
         "//pkg/waiter",
@@ -59,11 +60,3 @@ go_test(
         "//pkg/waiter",
     ],
 )
-
-filegroup(
-    name = "autogen",
-    srcs = [
-        "udp_packet_list.go",
-    ],
-    visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6e87245b7..1ac4705af 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -21,6 +21,7 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/iptables"
+	"gvisor.dev/gvisor/pkg/tcpip/ports"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/waiter"
 )
@@ -31,9 +32,6 @@ type udpPacket struct {
 	senderAddress tcpip.FullAddress
 	data          buffer.VectorisedView `state:".(buffer.VectorisedView)"`
 	timestamp     int64
-	// views is used as buffer for data when its length is large
-	// enough to store a VectorisedView.
-	views [8]buffer.View `state:"nosave"`
 }
 
 // EndpointState represents the state of a UDP endpoint.
@@ -80,6 +78,7 @@ type endpoint struct {
 	// change throughout the lifetime of the endpoint.
 	stack       *stack.Stack `state:"manual"`
 	waiterQueue *waiter.Queue
+	uniqueID    uint64
 
 	// The following fields are used to manage the receive queue, and are
 	// protected by rcvMu.
@@ -106,6 +105,11 @@ type endpoint struct {
 	bindToDevice tcpip.NICID
 	broadcast    bool
 
+	// Values used to reserve a port or register a transport endpoint
+	// (whichever happens first).
+ boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 @@ -140,7 +144,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, - TransProto: header.TCPProtocolNumber, + TransProto: header.UDPProtocolNumber, }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership @@ -160,9 +164,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -171,8 +181,10 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} } for _, mem := range e.multicastMemberships { @@ -278,7 +290,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -286,20 +298,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr } if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicid == 0 { - nicid = e.multicastNICID + if nicID == 0 { + nicID = e.multicastNICID } - if localAddr == "" && nicid == 0 { + if localAddr == "" && nicID == 0 { localAddr = e.multicastAddr } } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, err } - return r, nicid, nil + return r, nicID, nil } // Write writes data to the endpoint's peer. This method does not block @@ -378,13 +390,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. 
- nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -396,7 +408,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - r, _, err := e.connectRoute(nicid, *to, netProto) + r, _, err := e.connectRoute(nicID, *to, netProto) if err != nil { return 0, nil, err } @@ -618,9 +630,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -728,6 +740,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.MulticastLoopOption(v) return nil + case *tcpip.ReuseAddressOption: + *o = 0 + return nil + case *tcpip.ReusePortOption: e.mu.RLock() v := e.reusePort @@ -809,7 +825,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: hdr, + Data: data, + TransportHeader: buffer.View(udp), + }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -859,7 +879,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { if e.state != StateConnected { return nil } - id := stack.TransportEndpointID{} + var ( + id stack.TransportEndpointID + btd tcpip.NICID + ) // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -867,7 +890,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } @@ -875,13 +898,15 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. 
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.boundPortFlags = ports.Flags{} } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.ID = id + e.boundBindToDevice = btd e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -903,7 +928,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC var localPort uint16 switch e.state { case StateInitial: @@ -913,16 +938,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - r, nicid, err := e.connectRoute(nicid, addr, netProto) + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err } @@ -950,20 +975,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) } e.ID = id + e.boundBindToDevice = btd e.route = r.Clone() e.dstPort = addr.Port - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos e.state = StateConnected @@ -1018,20 +1044,27 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) + flags := ports.Flags{ + LoadBalanced: e.reusePort, + // FIXME(b/129164367): Support SO_REUSEADDR. 
+ MostRecent: false, + } + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) if err != nil { - return id, err + return id, e.bindToDevice, err } + e.boundPortFlags = flags id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.boundPortFlags = ports.Flags{} } - return id, err + return id, e.bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1057,11 +1090,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } } - nicid := addr.NIC + nicID := addr.NIC if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { // A local unicast address was specified, verify that it's valid. - nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicid == 0 { + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { return tcpip.ErrBadLocalAddress } } @@ -1070,13 +1103,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicid, netProtos, id) + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id - e.RegisterNICID = nicid + e.boundBindToDevice = btd + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -1111,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + return tcpip.FullAddress{ NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, + Addr: addr, Port: e.ID.LocalPort, }, nil } @@ -1154,17 +1193,17 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + pkt.Data.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() @@ -1188,18 +1227,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. 
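The boundPortFlags and boundBindToDevice bookkeeping introduced above exists so that ReleasePort is always called with exactly the flags and device that the matching ReservePort succeeded with; releasing with different values would leave the port reservation tables inconsistent. A sketch of that invariant, before HandlePacket continues (reserveThenRelease is a hypothetical illustration, not stack API):

	func reserveThenRelease(s *stack.Stack, netProtos []tcpip.NetworkProtocolNumber, addr tcpip.Address, dev tcpip.NICID) *tcpip.Error {
		flags := ports.Flags{LoadBalanced: true} // e.g. SO_REUSEPORT.
		port, err := s.ReservePort(netProtos, udp.ProtocolNumber, addr, 0, flags, dev)
		if err != nil {
			return err
		}
		// The release must mirror the reservation exactly.
		s.ReleasePort(netProtos, udp.ProtocolNumber, addr, port, flags, dev)
		return nil
	}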
- pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() @@ -1210,7 +1249,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. @@ -1234,6 +1273,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index b227e353b..43fb047ed 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -109,7 +109,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // pass it to the reservation machinery. id := e.ID e.ID.LocalPort = 0 - e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d399ec722..fc706ede2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. 
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt tcpip.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index de026880f..259c3072a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -116,13 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) 
+ payload := newNetHeader.ToVectorisedView() + payload.Append(pkt.Data.ToView().ToVectorisedView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -130,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -151,12 +159,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -164,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b724d788c..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -356,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw if p.Proto != flow.netProto() { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) h := flow.header4Tuple(outgoing) checkers := append( @@ -397,7 +397,8 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) @@ -431,7 +432,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. 
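Stepping back to the deep copy in HandleUnknownDestinationPacket above: appending to a nil buffer.View is the idiom that performs it. The result shares no storage with pkt.NetworkHeader, so building the ICMP payload can neither be affected by nor affect other holders of the same buffers, including across save/restore. In isolation (the injection code for this test continues below):

	orig := buffer.View([]byte{0x45, 0x00, 0x00, 0x14})
	copied := append(buffer.View(nil), orig...)
	copied[0] = 0x00 // Leaves orig untouched.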
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } // injectV4Packet creates a V4 test packet with the given payload and header @@ -441,7 +446,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) @@ -471,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } func newPayload() []byte { @@ -1442,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1516,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 1f7efb064..0427bc41f 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -34,11 +34,3 @@ go_test( ], embed = [":waiter"], ) - -filegroup( - name = "autogen", - srcs = [ - "waiter_list.go", - ], - visibility = ["//:sandbox"], -) |
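Taken together, the packet-plumbing changes in this commit follow one mechanical pattern: the link endpoint's Inject(proto, vv) becomes InjectInbound(proto, tcpip.PacketBuffer{...}), and outgoing wire images are reassembled from Pkt.Header and Pkt.Data rather than the old Header/Payload pair. A condensed sketch of both directions, assuming a channel-based link endpoint linkEP as used by these test contexts:

	// Inbound: wrap the serialized bytes in a PacketBuffer.
	linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
		Data: buf.ToVectorisedView(),
	})

	// Outbound: rebuild the wire image from the header and data views.
	p := <-linkEP.C
	hdr := p.Pkt.Header.View()
	wire := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
	_ = wire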