Diffstat (limited to 'pkg')
396 files changed, 21995 insertions, 6555 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 114b516e2..05ca5342f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -23,11 +23,13 @@ go_library( "errors.go", "eventfd.go", "exec.go", + "fadvise.go", "fcntl.go", "file.go", "file_amd64.go", "file_arm64.go", "fs.go", + "fuse.go", "futex.go", "inotify.go", "ioctl.go", @@ -71,6 +73,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go index 3c6e0079d..86ee3f8b5 100644 --- a/pkg/abi/linux/aio.go +++ b/pkg/abi/linux/aio.go @@ -14,7 +14,63 @@ package linux +import "encoding/binary" + +// AIORingSize is sizeof(struct aio_ring). +const AIORingSize = 32 + +// I/O commands. const ( - // AIORingSize is sizeof(struct aio_ring). - AIORingSize = 32 + IOCB_CMD_PREAD = 0 + IOCB_CMD_PWRITE = 1 + IOCB_CMD_FSYNC = 2 + IOCB_CMD_FDSYNC = 3 + // 4 was the experimental IOCB_CMD_PREADX. + IOCB_CMD_POLL = 5 + IOCB_CMD_NOOP = 6 + IOCB_CMD_PREADV = 7 + IOCB_CMD_PWRITEV = 8 ) + +// I/O flags. +const ( + IOCB_FLAG_RESFD = 1 + IOCB_FLAG_IOPRIO = 2 +) + +// IOCallback describes an I/O request. +// +// The priority field is currently ignored in the implementation below. Also +// note that the IOCB_FLAG_RESFD feature is not supported. +type IOCallback struct { + Data uint64 + Key uint32 + _ uint32 + + OpCode uint16 + ReqPrio int16 + FD int32 + + Buf uint64 + Bytes uint64 + Offset int64 + + Reserved2 uint64 + Flags uint32 + + // eventfd to signal if IOCB_FLAG_RESFD is set in flags. + ResFD int32 +} + +// IOEvent describes an I/O result. +// +// +stateify savable +type IOEvent struct { + Data uint64 + Obj uint64 + Result int64 + Result2 int64 +} + +// IOEventSize is the size of an ioEvent encoded. +var IOEventSize = binary.Size(IOEvent{}) diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go index fa3ae5f18..192e2093b 100644 --- a/pkg/abi/linux/dev.go +++ b/pkg/abi/linux/dev.go @@ -46,6 +46,10 @@ const ( // TTYAUX_MAJOR is the major device number for alternate TTY devices. TTYAUX_MAJOR = 5 + // MISC_MAJOR is the major device number for non-serial mice, misc feature + // devices. + MISC_MAJOR = 10 + // UNIX98_PTY_MASTER_MAJOR is the initial major device number for // Unix98 PTY masters. UNIX98_PTY_MASTER_MAJOR = 128 diff --git a/pkg/abi/linux/fadvise.go b/pkg/abi/linux/fadvise.go new file mode 100644 index 000000000..b06ff9964 --- /dev/null +++ b/pkg/abi/linux/fadvise.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
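[Editor's note] aio.go above derives IOEventSize from binary.Size at package init rather than hard-coding 32. A minimal standalone sketch of that pattern, with the struct copied from the diff (the little-endian byte order below is an illustrative assumption; the sentry writes events in the host's byte order):

    package main

    import (
        "bytes"
        "encoding/binary"
        "fmt"
    )

    // IOEvent mirrors linux.IOEvent from the diff above.
    type IOEvent struct {
        Data    uint64
        Obj     uint64
        Result  int64
        Result2 int64
    }

    func main() {
        // binary.Size sums the fixed-size fields, exactly how aio.go
        // computes IOEventSize.
        fmt.Println(binary.Size(IOEvent{})) // 32

        // Encode one completion event as it would sit in the AIO ring.
        var buf bytes.Buffer
        binary.Write(&buf, binary.LittleEndian, IOEvent{Data: 42, Result: 8})
        fmt.Println(buf.Len()) // 32
    }
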
+ +package linux + +const ( + POSIX_FADV_NORMAL = 0 + POSIX_FADV_RANDOM = 1 + POSIX_FADV_SEQUENTIAL = 2 + POSIX_FADV_WILLNEED = 3 + POSIX_FADV_DONTNEED = 4 + POSIX_FADV_NOREUSE = 5 +) diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index 6663a199c..9242e80a5 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -55,7 +55,7 @@ type Flock struct { _ [4]byte } -// Flags for F_SETOWN_EX and F_GETOWN_EX. +// Owner types for F_SETOWN_EX and F_GETOWN_EX. const ( F_OWNER_TID = 0 F_OWNER_PID = 1 diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 055ac1d7c..e11ca2d62 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -191,8 +191,9 @@ var DirentType = abi.ValueSet{ // Values for preadv2/pwritev2. const ( - // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is - // accepted as a valid flag argument for preadv2/pwritev2. + // NOTE(b/120162627): gVisor does not implement the RWF_HIPRI feature, but + // the flag is accepted as a valid flag argument for preadv2/pwritev2 and + // silently ignored. RWF_HIPRI = 0x00000001 RWF_DSYNC = 0x00000002 RWF_SYNC = 0x00000004 diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go new file mode 100644 index 000000000..d3ebbccc4 --- /dev/null +++ b/pkg/abi/linux/fuse.go @@ -0,0 +1,143 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// +marshal +type FUSEOpcode uint32 + +// +marshal +type FUSEOpID uint64 + +// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +const ( + FUSE_LOOKUP FUSEOpcode = 1 + FUSE_FORGET = 2 /* no reply */ + FUSE_GETATTR = 3 + FUSE_SETATTR = 4 + FUSE_READLINK = 5 + FUSE_SYMLINK = 6 + _ + FUSE_MKNOD = 8 + FUSE_MKDIR = 9 + FUSE_UNLINK = 10 + FUSE_RMDIR = 11 + FUSE_RENAME = 12 + FUSE_LINK = 13 + FUSE_OPEN = 14 + FUSE_READ = 15 + FUSE_WRITE = 16 + FUSE_STATFS = 17 + FUSE_RELEASE = 18 + _ + FUSE_FSYNC = 20 + FUSE_SETXATTR = 21 + FUSE_GETXATTR = 22 + FUSE_LISTXATTR = 23 + FUSE_REMOVEXATTR = 24 + FUSE_FLUSH = 25 + FUSE_INIT = 26 + FUSE_OPENDIR = 27 + FUSE_READDIR = 28 + FUSE_RELEASEDIR = 29 + FUSE_FSYNCDIR = 30 + FUSE_GETLK = 31 + FUSE_SETLK = 32 + FUSE_SETLKW = 33 + FUSE_ACCESS = 34 + FUSE_CREATE = 35 + FUSE_INTERRUPT = 36 + FUSE_BMAP = 37 + FUSE_DESTROY = 38 + FUSE_IOCTL = 39 + FUSE_POLL = 40 + FUSE_NOTIFY_REPLY = 41 + FUSE_BATCH_FORGET = 42 +) + +const ( + // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem. + // This is the minimum size Linux supports. See linux.fuse.h. + FUSE_MIN_READ_BUFFER uint32 = 8192 +) + +// FUSEHeaderIn is the header read by the daemon with each request. +// +// +marshal +type FUSEHeaderIn struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Opcode specifies the kind of operation of the request. + Opcode FUSEOpcode + + // Unique specifies the unique identifier for this request. + Unique FUSEOpID + + // NodeID is the ID of the filesystem object being operated on. 
+ NodeID uint64 + + // UID is the UID of the requesting process. + UID uint32 + + // GID is the GID of the requesting process. + GID uint32 + + // PID is the PID of the requesting process. + PID uint32 + + _ uint32 +} + +// FUSEHeaderOut is the header written by the daemon when it processes +// a request and wants to send a reply (almost all operations require a +// reply; if they do not, this will be explicitly documented). +// +// +marshal +type FUSEHeaderOut struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Error specifies the error that occurred (0 if none). + Error int32 + + // Unique specifies the unique identifier of the corresponding request. + Unique FUSEOpID +} + +// FUSEWriteIn is the header written by a daemon when it makes a +// write request to the FUSE filesystem. +// +// +marshal +type FUSEWriteIn struct { + // Fh specifies the file handle that is being written to. + Fh uint64 + + // Offset is the offset of the write. + Offset uint64 + + // Size is the size of data being written. + Size uint32 + + // WriteFlags is the flags used during the write. + WriteFlags uint32 + + // LockOwner is the ID of the lock owner. + LockOwner uint64 + + // Flags is the flags for the request. + Flags uint32 + + _ uint32 +} diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go index 08bfde3b5..8138088a6 100644 --- a/pkg/abi/linux/futex.go +++ b/pkg/abi/linux/futex.go @@ -60,3 +60,21 @@ const ( FUTEX_WAITERS = 0x80000000 FUTEX_OWNER_DIED = 0x40000000 ) + +// FUTEX_BITSET_MATCH_ANY has all bits set. +const FUTEX_BITSET_MATCH_ANY = 0xffffffff + +// ROBUST_LIST_LIMIT protects against a deliberately circular list. +const ROBUST_LIST_LIMIT = 2048 + +// RobustListHead corresponds to Linux's struct robust_list_head. +// +// +marshal +type RobustListHead struct { + List uint64 + FutexOffset uint64 + ListOpPending uint64 +} + +// SizeOfRobustListHead is the size of a RobustListHead struct. +var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes() diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2062e6a4b..2c5e56ae5 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -67,10 +67,29 @@ const ( // ioctl(2) requests provided by uapi/linux/sockios.h const ( - SIOCGIFMEM = 0x891f - SIOCGIFPFLAGS = 0x8935 - SIOCGMIIPHY = 0x8947 - SIOCGMIIREG = 0x8948 + SIOCGIFNAME = 0x8910 + SIOCGIFCONF = 0x8912 + SIOCGIFFLAGS = 0x8913 + SIOCGIFADDR = 0x8915 + SIOCGIFDSTADDR = 0x8917 + SIOCGIFBRDADDR = 0x8919 + SIOCGIFNETMASK = 0x891b + SIOCGIFMETRIC = 0x891d + SIOCGIFMTU = 0x8921 + SIOCGIFMEM = 0x891f + SIOCGIFHWADDR = 0x8927 + SIOCGIFINDEX = 0x8933 + SIOCGIFPFLAGS = 0x8935 + SIOCGIFTXQLEN = 0x8942 + SIOCETHTOOL = 0x8946 + SIOCGMIIPHY = 0x8947 + SIOCGMIIREG = 0x8948 + SIOCGIFMAP = 0x8970 +) + +// ioctl(2) requests provided by uapi/asm-generic/sockios.h +const ( + SIOCGSTAMP = 0x8906 ) // ioctl(2) directions. Used to calculate requests number. diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 7866352b4..0faf015c7 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -22,6 +22,8 @@ const ( ) // IFReq is an interface request. +// +// +marshal type IFReq struct { // IFName is an encoded name, normally null-terminated. This should be // accessed via the Name and SetName functions. @@ -79,6 +81,8 @@ type IFMap struct { // IFConf is used to return a list of interfaces and their addresses. See // netdevice(7) and struct ifconf for more detail on its use. 
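[Editor's note] The FUSE types added in fuse.go above frame every message with a fixed header whose Len field covers the header plus its payload. A hedged sketch of that sizing rule (struct layouts copied from the diff with the FUSEOpcode/FUSEOpID typedefs flattened to their underlying widths; the 24-byte body is hypothetical):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    type fuseHeaderIn struct {
        Len    uint32
        Opcode uint32 // FUSEOpcode
        Unique uint64 // FUSEOpID
        NodeID uint64
        UID    uint32
        GID    uint32
        PID    uint32
        _      uint32
    }

    type fuseHeaderOut struct {
        Len    uint32
        Error  int32
        Unique uint64 // FUSEOpID
    }

    func main() {
        in := binary.Size(fuseHeaderIn{})   // 40, matching Linux's fuse_in_header
        out := binary.Size(fuseHeaderOut{}) // 16, matching fuse_out_header

        // Len always includes the header itself, so a request carrying a
        // hypothetical 24-byte opcode-specific body advertises Len = 64.
        fmt.Println(uint32(in + 24))
        // An error reply is just the 16-byte header with Error set.
        fmt.Println(uint32(out))
    }
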
+// +// +marshal type IFConf struct { Len int32 _ [4]byte // Pad to sizeof(struct ifconf). diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..a91f9f018 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -76,6 +84,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +122,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +219,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +386,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. 
+// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type is not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type is not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,12 +500,6 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 @@ -392,6 +512,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 40bec566c..ceda0a8d3 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -187,6 +187,8 @@ const ( // Device types, from uapi/linux/if_arp.h. const ( + ARPHRD_NONE = 65534 + ARPHRD_ETHER = 1 ARPHRD_LOOPBACK = 772 ) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 4a14ef691..c24a8216e 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -134,6 +134,15 @@ const ( SHUT_RDWR = 2 ) +// Packet types from <linux/if_packet.h> +const ( + PACKET_HOST = 0 // To us + PACKET_BROADCAST = 1 // To all + PACKET_MULTICAST = 2 // To group + PACKET_OTHERHOST = 3 // To someone else + PACKET_OUTGOING = 4 // Outgoing of any type +) + // Socket options from socket.h. const ( SO_DEBUG = 1 @@ -225,6 +234,8 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. @@ -294,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -308,6 +321,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -405,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. 
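[Editor's note] KernelIPTEntry and KernelIPTGetEntries above are marshalled by hand because go_marshal cannot generate code for variable-length types. Every such implementation has the same shape: SizeBytes sums the fixed part and the variable tail, and the (un)marshal methods walk the buffer with a running offset. A stripped-down, self-contained sketch of that pattern (the message type is illustrative, not part of the tree):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // message is a fixed-size header (a length) followed by a variable
    // tail, analogous to IPTGetEntries followed by Entrytable.
    type message struct {
        tail []byte
    }

    func (m *message) SizeBytes() int {
        return 4 + len(m.tail) // fixed part + variable part
    }

    func (m *message) MarshalBytes(dst []byte) {
        binary.LittleEndian.PutUint32(dst, uint32(len(m.tail)))
        copy(dst[4:], m.tail) // resume past the fixed header
    }

    func (m *message) UnmarshalBytes(src []byte) {
        n := binary.LittleEndian.Uint32(src)
        m.tail = append(m.tail[:0], src[4:4+n]...)
    }

    func main() {
        in := message{tail: []byte("iptables")}
        buf := make([]byte, in.SizeBytes())
        in.MarshalBytes(buf)

        var out message
        out.UnmarshalBytes(buf)
        fmt.Printf("%s\n", out.tail) // iptables
    }
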
+// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go index 547921d0a..c85d786b9 100644 --- a/pkg/bpf/interpreter_test.go +++ b/pkg/bpf/interpreter_test.go @@ -767,7 +767,7 @@ func TestSimpleFilter(t *testing.T) { expectedRet: 0, }, { - desc: "Whitelisted syscall is allowed", + desc: "Allowed syscall is indeed allowed", seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e}, expectedRet: 0x7fff0000, }, diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 5f52cbe74..b094c5662 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -346,20 +346,22 @@ func (p *pool) schedule(c *chunk, callback func(*chunk) error) error { } } -// reader chunks reads and decompresses. -type reader struct { +// Reader is a compressed reader. +type Reader struct { pool // in is the source. in io.Reader } +var _ io.Reader = (*Reader)(nil) + // NewReader returns a new compressed reader. If key is non-nil, the data stream // is assumed to contain expected hash values, which will be compared against // hash values computed from the compressed bytes. See package comments for // details. -func NewReader(in io.Reader, key []byte) (io.Reader, error) { - r := &reader{ +func NewReader(in io.Reader, key []byte) (*Reader, error) { + r := &Reader{ in: in, } @@ -394,8 +396,19 @@ var errNewBuffer = errors.New("buffer ready") // ErrHashMismatch is returned if the hash does not match. var ErrHashMismatch = errors.New("hash mismatch") +// ReadByte implements wire.Reader.ReadByte. +func (r *Reader) ReadByte() (byte, error) { + var p [1]byte + n, err := r.Read(p[:]) + if n != 1 { + return p[0], err + } + // Suppress EOF. + return p[0], nil +} + // Read implements io.Reader.Read. -func (r *reader) Read(p []byte) (int, error) { +func (r *Reader) Read(p []byte) (int, error) { r.mu.Lock() defer r.mu.Unlock() @@ -551,8 +564,8 @@ func (r *reader) Read(p []byte) (int, error) { return done, nil } -// writer chunks and schedules writes. -type writer struct { +// Writer is a compressed writer. +type Writer struct { pool // out is the underlying writer. @@ -562,6 +575,8 @@ type writer struct { closed bool } +var _ io.Writer = (*Writer)(nil) + // NewWriter returns a new compressed writer. If key is non-nil, hash values are // generated and written out for compressed bytes. See package comments for // details. @@ -569,8 +584,8 @@ type writer struct { // The recommended chunkSize is on the order of 1M. Extra memory may be // buffered (in the form of read-ahead, or buffered writes), and is limited to // O(chunkSize * [1+GOMAXPROCS]). -func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) { - w := &writer{ +func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) { + w := &Writer{ pool: pool{ chunkSize: chunkSize, buf: bufPool.Get().(*bytes.Buffer), @@ -597,7 +612,7 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write } // flush writes a single buffer. -func (w *writer) flush(c *chunk) error { +func (w *Writer) flush(c *chunk) error { // Prefix each chunk with a length; this allows the reader to safely // limit reads while buffering. l := uint32(c.compressed.Len()) @@ -624,8 +639,23 @@ func (w *writer) flush(c *chunk) error { return nil } +// WriteByte implements wire.Writer.WriteByte. 
+// +// Note that this implementation is necessary on the object itself, as an +// interface-based dispatch cannot tell whether the array backing the slice +// escapes, therefore all bytes written will generate an escape. +func (w *Writer) WriteByte(b byte) error { + var p [1]byte + p[0] = b + n, err := w.Write(p[:]) + if n != 1 { + return err + } + return nil +} + // Write implements io.Writer.Write. -func (w *writer) Write(p []byte) (int, error) { +func (w *Writer) Write(p []byte) (int, error) { w.mu.Lock() defer w.mu.Unlock() @@ -710,7 +740,7 @@ func (w *writer) Write(p []byte) (int, error) { } // Close implements io.Closer.Close. -func (w *writer) Close() error { +func (w *Writer) Close() error { w.mu.Lock() defer w.mu.Unlock() diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go index 08381c1c0..ac7bb6774 100644 --- a/pkg/cpuid/cpuid_arm64.go +++ b/pkg/cpuid/cpuid_arm64.go @@ -312,8 +312,9 @@ func HostFeatureSet() *FeatureSet { } } -// Reads bogomips from host /proc/cpuinfo. Must run before whitelisting. -// This value is used to create the fake /proc/cpuinfo from a FeatureSet. +// Reads bogomips from host /proc/cpuinfo. Must run before syscall filter +// installation. This value is used to create the fake /proc/cpuinfo from a +// FeatureSet. func initCPUInfo() { cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo") if err != nil { diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index 562f8f405..17a89c00d 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -1057,9 +1057,9 @@ func HostFeatureSet() *FeatureSet { } } -// Reads max cpu frequency from host /proc/cpuinfo. Must run before -// whitelisting. This value is used to create the fake /proc/cpuinfo from a -// FeatureSet. +// Reads max cpu frequency from host /proc/cpuinfo. Must run before syscall +// filter installation. This value is used to create the fake /proc/cpuinfo +// from a FeatureSet. func initCPUFreq() { cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo") if err != nil { @@ -1106,7 +1106,6 @@ func initFeaturesFromString() { } func init() { - // initCpuFreq must be run before whitelists are enabled. initCPUFreq() initFeaturesFromString() } diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD index 798a65eca..35683fe98 100644 --- a/pkg/gohacks/BUILD +++ b/pkg/gohacks/BUILD @@ -7,5 +7,6 @@ go_library( srcs = [ "gohacks_unsafe.go", ], + stateify = False, visibility = ["//:sandbox"], ) diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 0d07da3b1..f4a4c33d3 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -90,7 +90,7 @@ func (l *List) Back() Element { // // NOTE: This is an O(n) operation.
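[Editor's note] With the compressio reader and writer now exported as Reader and Writer, the package can be driven directly through the signatures shown in this diff. A round-trip usage sketch (a nil key disables hashing; the 1 MB chunk size follows the recommendation in NewWriter's doc comment, and compression level 1 is an illustrative choice):

    package main

    import (
        "bytes"
        "fmt"
        "io/ioutil"

        "gvisor.dev/gvisor/pkg/compressio"
    )

    func main() {
        var buf bytes.Buffer

        w, err := compressio.NewWriter(&buf, nil, 1024*1024, 1)
        if err != nil {
            panic(err)
        }
        w.Write([]byte("hello, sentry"))
        w.Close() // flushes any outstanding chunks

        r, err := compressio.NewReader(&buf, nil)
        if err != nil {
            panic(err)
        }
        out, _ := ioutil.ReadAll(r)
        fmt.Printf("%s\n", out) // hello, sentry
    }
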
func (l *List) Len() (count int) { - for e := l.Front(); e != nil; e = e.Next() { + for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() { count++ } return count @@ -182,13 +182,13 @@ func (l *List) Remove(e Element) { if prev != nil { ElementMapper{}.linkerFor(prev).SetNext(next) - } else { + } else if l.head == e { l.head = next } if next != nil { ElementMapper{}.linkerFor(next).SetPrev(prev) - } else { + } else if l.tail == e { l.tail = prev } diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD new file mode 100644 index 000000000..eda82cfc1 --- /dev/null +++ b/pkg/iovec/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "iovec", + srcs = ["iovec.go"], + visibility = ["//:sandbox"], + deps = ["//pkg/abi/linux"], +) + +go_test( + name = "iovec_test", + size = "small", + srcs = ["iovec_test.go"], + library = ":iovec", + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go new file mode 100644 index 000000000..dd70fe80f --- /dev/null +++ b/pkg/iovec/iovec.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package iovec provides helpers to interact with vectorized I/O on host +// system. +package iovec + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// MaxIovs is the maximum number of iovecs host platform can accept. +var MaxIovs = linux.UIO_MAXIOV + +// Builder is a builder for slice of syscall.Iovec. +type Builder struct { + iovec []syscall.Iovec + storage [8]syscall.Iovec + + // overflow tracks the last buffer when iovec length is at MaxIovs. + overflow []byte +} + +// Add adds buf to b preparing to be written. Zero-length buf won't be added. +func (b *Builder) Add(buf []byte) { + if len(buf) == 0 { + return + } + if b.iovec == nil { + b.iovec = b.storage[:0] + } + if len(b.iovec) >= MaxIovs { + b.addByAppend(buf) + return + } + b.iovec = append(b.iovec, syscall.Iovec{ + Base: &buf[0], + Len: uint64(len(buf)), + }) + // Keep the last buf if iovec is at max capacity. We will need to append to it + // for later bufs. + if len(b.iovec) == MaxIovs { + n := len(buf) + b.overflow = buf[:n:n] + } +} + +func (b *Builder) addByAppend(buf []byte) { + b.overflow = append(b.overflow, buf...) + b.iovec[len(b.iovec)-1] = syscall.Iovec{ + Base: &b.overflow[0], + Len: uint64(len(b.overflow)), + } +} + +// Build returns the final Iovec slice. The length of returned iovec will not + exceed MaxIovs. +func (b *Builder) Build() []syscall.Iovec { + return b.iovec +} diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go new file mode 100644 index 000000000..a3900c299 --- /dev/null +++ b/pkg/iovec/iovec_test.go @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package iovec + +import ( + "bytes" + "fmt" + "syscall" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func TestBuilderEmpty(t *testing.T) { + var builder Builder + iovecs := builder.Build() + if got, want := len(iovecs), 0; got != want { + t.Errorf("len(iovecs) = %d, want %d", got, want) + } +} + +func TestBuilderBuild(t *testing.T) { + a := []byte{1, 2} + b := []byte{3, 4, 5} + + var builder Builder + builder.Add(a) + builder.Add(b) + builder.Add(nil) // Nil slice won't be added. + builder.Add([]byte{}) // Empty slice won't be added. + iovecs := builder.Build() + + if got, want := len(iovecs), 2; got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } + for i, data := range [][]byte{a, b} { + if got, want := *iovecs[i].Base, data[0]; got != want { + t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want) + } + if got, want := iovecs[i].Len, uint64(len(data)); got != want { + t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want) + } + } +} + +func TestBuilderBuildMaxIov(t *testing.T) { + for _, test := range []struct { + numIov int + }{ + { + numIov: MaxIovs - 1, + }, + { + numIov: MaxIovs, + }, + { + numIov: MaxIovs + 1, + }, + { + numIov: MaxIovs + 10, + }, + } { + name := fmt.Sprintf("numIov=%v", test.numIov) + t.Run(name, func(t *testing.T) { + var data []byte + var builder Builder + for i := 0; i < test.numIov; i++ { + buf := []byte{byte(i)} + builder.Add(buf) + data = append(data, buf...) + } + iovec := builder.Build() + + // Check the expected length of iovec. + wantNum := test.numIov + if wantNum > MaxIovs { + wantNum = MaxIovs + } + if got, want := len(iovec), wantNum; got != want { + t.Errorf("len(iovec) = %d, want %d", got, want) + } + + // Test a real read-write. + var fds [2]int + if err := unix.Pipe(fds[:]); err != nil { + t.Fatalf("Pipe: %v", err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if int(wrote) != len(data) || e != 0 { + t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data)) + } + + got := make([]byte, len(data)) + if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil { + t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got)) + } + + if !bytes.Equal(got, data) { + t.Errorf("read: got data %v, want %v", got, data) + } + }) + } +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 57b89ad7d..2cb59f934 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -2506,7 +2506,7 @@ type msgFactory struct { var msgRegistry registry type registry struct { - factories [math.MaxUint8]msgFactory + factories [math.MaxUint8 + 1]msgFactory // largestFixedSize is computed so that given some message size M, you can // compute the maximum payload size (e.g. for Twrite, Rread) with diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 28d851ff5..122c457d2 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -1091,6 +1091,19 @@ type AllocateMode struct { Unshare bool } +// ToAllocateMode returns an AllocateMode from a fallocate(2) mode. 
+func ToAllocateMode(mode uint64) AllocateMode { + return AllocateMode{ + KeepSize: mode&unix.FALLOC_FL_KEEP_SIZE != 0, + PunchHole: mode&unix.FALLOC_FL_PUNCH_HOLE != 0, + NoHideStale: mode&unix.FALLOC_FL_NO_HIDE_STALE != 0, + CollapseRange: mode&unix.FALLOC_FL_COLLAPSE_RANGE != 0, + ZeroRange: mode&unix.FALLOC_FL_ZERO_RANGE != 0, + InsertRange: mode&unix.FALLOC_FL_INSERT_RANGE != 0, + Unshare: mode&unix.FALLOC_FL_UNSHARE_RANGE != 0, + } +} + // ToLinux converts to a value compatible with fallocate(2)'s mode. func (a *AllocateMode) ToLinux() uint32 { rv := uint32(0) diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index 06308cd29..a52dc1b4e 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -56,7 +56,7 @@ func (a AllowValue) String() (s string) { return fmt.Sprintf("%#x ", uintptr(a)) } -// Rule stores the whitelist of syscall arguments. +// Rule stores the allowed syscall arguments. // // For example: // rule := Rule { @@ -82,7 +82,7 @@ func (r Rule) String() (s string) { return } -// SyscallRules stores a map of OR'ed whitelist rules indexed by the syscall number. +// SyscallRules stores a map of OR'ed argument rules indexed by the syscall number. // If the 'Rules' is empty, we treat it as any argument is allowed. // // For example: diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index daba8b172..fd95eb2d2 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -28,7 +28,14 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs + + // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register. + TPIDR_EL0 uint64 +} const ( // SyscallWidth is the width of insturctions. @@ -101,9 +108,6 @@ type State struct { // Our floating point state. aarch64FPState `state:"wait"` - // TLS pointer - TPValue uint64 - // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -157,7 +161,6 @@ func (s *State) Fork() State { return State{ Regs: s.Regs, aarch64FPState: s.aarch64FPState.fork(), - TPValue: s.TPValue, FeatureSet: s.FeatureSet, OrigR0: s.OrigR0, } @@ -241,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers { return s.Regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } regs.UnmarshalUnsafe(buf) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // PtraceGetFPRegs implements Context.PtraceGetFPRegs. 
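[Editor's note] The registersSize → ptraceRegistersSize rename in this hunk is load-bearing: Registers now embeds linux.PtraceRegs (adding TPIDR_EL0 on arm64), but the PTRACE_GETREGSET wire format must stay exactly the size of the embedded struct. A sketch of why the two sizes diverge, using local stand-in types with an abbreviated field set (the real layouts live in the abi/linux package):

    package main

    import (
        "fmt"
        "unsafe"
    )

    // ptraceRegs stands in for linux.PtraceRegs (fields abbreviated).
    type ptraceRegs struct {
        Regs   [31]uint64
        Sp, Pc uint64
    }

    // registers mirrors the new arch.Registers wrapper on arm64.
    type registers struct {
        ptraceRegs
        TPIDR_EL0 uint64 // sentry-internal; never exposed via ptrace
    }

    func main() {
        // Buffers exchanged with ptrace must be sized from the embedded
        // struct...
        fmt.Println(unsafe.Sizeof(ptraceRegs{})) // 264 for this sketch
        // ...because the wrapper is strictly larger.
        fmt.Println(unsafe.Sizeof(registers{})) // 272 for this sketch
    }
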
@@ -278,7 +281,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -291,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 3b3a0a272..1c3e3c14c 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and // u_debugreg, returning 0 or silently no-oping for other fields // respectively. - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) @@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { if addr&7 != 0 || addr >= userStructSize { return syscall.EIO } - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index ada7ac7b8..cabbf60e0 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) { // TLS returns the current TLS pointer. func (c *context64) TLS() uintptr { - return uintptr(c.TPValue) + return uintptr(c.Regs.TPIDR_EL0) } // SetTLS sets the current TLS pointer. Returns false if value is invalid. @@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool { return false } - c.TPValue = uint64(value) + c.Regs.TPIDR_EL0 = uint64(value) return true } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index dc458b37f..b9405b320 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -31,7 +31,11 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs +} // System-related constants for x86. const ( @@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers { return regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. 
func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } @@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) { } regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // isUserSegmentSelector returns true if the given segment selector specifies a @@ -543,7 +547,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/control/logging.go b/pkg/sentry/control/logging.go index 811f24324..8a500a515 100644 --- a/pkg/sentry/control/logging.go +++ b/pkg/sentry/control/logging.go @@ -70,8 +70,8 @@ type LoggingArgs struct { type Logging struct{} // Change will change the log level and strace arguments. Although -// this functions signature requires an error it never acctually -// return san error. It's required by the URPC interface. +// this functions signature requires an error it never actually +// returns an error. It's required by the URPC interface. // Additionally, it may look odd that this is the only method // attached to an empty struct but this is also part of how // URPC dispatches. diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index 69e71e322..f45b2bd2b 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -188,6 +188,9 @@ type MultiDevice struct { // String stringifies MultiDevice. func (m *MultiDevice) String() string { + m.mu.Lock() + defer m.mu.Unlock() + buf := bytes.NewBuffer(nil) buf.WriteString("cache{") for k, v := range m.cache { diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD new file mode 100644 index 000000000..12e49b58a --- /dev/null +++ b/pkg/sentry/devices/ttydev/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "ttydev", + srcs = ["ttydev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/vfs", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go new file mode 100644 index 000000000..fbb7fd92c --- /dev/null +++ b/pkg/sentry/devices/ttydev/ttydev.go @@ -0,0 +1,91 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ttydev implements devices for /dev/tty and (eventually) +// /dev/console. +// +// TODO(b/159623826): Support /dev/console. +package ttydev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +const ( + // See drivers/tty/tty_io.c:tty_init(). + ttyDevMinor = 0 + consoleDevMinor = 1 +) + +// ttyDevice implements vfs.Device for /dev/tty. +type ttyDevice struct{} + +// Open implements vfs.Device.Open. +func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &ttyFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// ttyFD implements vfs.FileDescriptionImpl for /dev/tty. +type ttyFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *ttyFD) Release() {} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return src.NumBytes(), nil +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return src.NumBytes(), nil +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, ttyDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "tty", + }) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package. 
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "tty", vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD new file mode 100644 index 000000000..71c59287c --- /dev/null +++ b/pkg/sentry/devices/tundev/BUILD @@ -0,0 +1,23 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "tundev", + srcs = ["tundev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/inet", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netstack", + "//pkg/sentry/vfs", + "//pkg/syserror", + "//pkg/tcpip/link/tun", + "//pkg/usermem", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go new file mode 100644 index 000000000..dfbd069af --- /dev/null +++ b/pkg/sentry/devices/tundev/tundev.go @@ -0,0 +1,178 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tundev implements the /dev/net/tun device. +package tundev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// tunDevice implements vfs.Device for /dev/net/tun. +type tunDevice struct{} + +// Open implements vfs.Device.Open. +func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &tunFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +type tunFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD + + device tun.Device +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. 
+func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fd.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fd.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fd.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *tunFD) Release() { + fd.device.Release() +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *tunFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.Read(ctx, dst, opts) +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *tunFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + data, err := fd.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *tunFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.Write(ctx, src, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *tunFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fd.device.Write(data) +} + +// Readiness implements waiter.Waitable.Readiness. +func (fd *tunFD) Readiness(mask waiter.EventMask) waiter.EventMask { + return fd.device.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *tunFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.device.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *tunFD) EventUnregister(e *waiter.Entry) { + fd.device.EventUnregister(e) +} + +// isNetTunSupported returns whether /dev/net/tun device is supported for s. +func isNetTunSupported(s inet.Stack) bool { + _, ok := s.(*netstack.Stack) + return ok +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, netTunDevMajor, netTunDevMinor, tunDevice{}, &vfs.RegisterDeviceOptions{}) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package.
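[Editor's note] ttydev and tundev share a two-step wiring: Register the vfs.Device with the VirtualFilesystem, then create its node through a devtmpfs.Accessor. A hedged sketch of how a caller might string the two tundev calls together during VFS setup (the boot package name and registerTUN helper are hypothetical; only the tundev functions come from this diff):

    package boot

    import (
        "gvisor.dev/gvisor/pkg/context"
        "gvisor.dev/gvisor/pkg/sentry/devices/tundev"
        "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
        "gvisor.dev/gvisor/pkg/sentry/vfs"
    )

    // registerTUN makes /dev/net/tun (char 10:200) visible to the sandbox.
    func registerTUN(ctx context.Context, vfsObj *vfs.VirtualFilesystem, a *devtmpfs.Accessor) error {
        if err := tundev.Register(vfsObj); err != nil {
            return err
        }
        // Creates the device node with mode 0666, per CreateDevtmpfsFiles.
        return tundev.CreateDevtmpfsFiles(ctx, a)
    }
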
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "net/tun", vfs.CharDevice, netTunDevMajor, netTunDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index beba0f771..f5537411e 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -160,6 +160,7 @@ type FileOperations interface { // refer. // // Preconditions: The AddressSpace (if any) that io refers to is activated. + // Must only be called from a task goroutine. Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) } diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go index 084da2a8d..d41f30bbb 100644 --- a/pkg/sentry/fs/filesystems.go +++ b/pkg/sentry/fs/filesystems.go @@ -87,20 +87,6 @@ func RegisterFilesystem(f Filesystem) { filesystems.registered[f.Name()] = f } -// UnregisterFilesystem removes a file system from the global set. To keep the -// file system set compatible with save/restore, UnregisterFilesystem must be -// called before save/restore methods. -// -// For instance, packages may unregister their file system after it is mounted. -// This makes sense for pseudo file systems that should not be visible or -// mountable. See whitelistfs in fs/host/fs.go for one example. -func UnregisterFilesystem(name string) { - filesystems.mu.Lock() - defer filesystems.mu.Unlock() - - delete(filesystems.registered, name) -} - // FindFilesystem returns a Filesystem registered at name or (nil, false) if name // is not a file system type that can be found in /proc/filesystems. func FindFilesystem(name string) (Filesystem, bool) { diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 789369220..5fb419bcd 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -8,7 +8,6 @@ go_template_instance( out = "dirty_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "Dirty", @@ -25,14 +24,14 @@ go_template_instance( name = "frame_ref_set_impl", out = "frame_ref_set_impl.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "fsutil", prefix = "FrameRef", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "uint64", "Functions": "FrameRefSetFunctions", }, @@ -43,7 +42,6 @@ go_template_instance( out = "file_range_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "FileRange", @@ -86,7 +84,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/state", diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index c6cd45087..2c9446c1d 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) { // repeatedly until all bytes have been written. 
max is the true size of the // cached object; offsets beyond max will not be passed to writeAt, even if // they are marked dirty. -func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { var changedDirty bool defer func() { if changedDirty { @@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet // successful partial write, SyncDirtyAll will call it repeatedly until all // bytes have been written. max is the true size of the cached object; offsets // beyond max will not be passed to writeAt, even if they are marked dirty. -func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { dseg := dirty.FirstSegment() for dseg.Ok() { if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil { @@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max } // Preconditions: mr must be page-aligned. -func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() { wbr := cseg.Range().Intersect(mr) if max < wbr.Start { diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 5643cdac9..bbafebf03 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -23,13 +23,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a -// platform.File. It is used to implement Mappables that store data in +// memmap.File. It is used to implement Mappables that store data in // sparsely-allocated memory. // // type FileRangeSet <generated by go_generics> @@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli } // FileRange returns the FileRange mapped by seg. -func (seg FileRangeIterator) FileRange() platform.FileRange { +func (seg FileRangeIterator) FileRange() memmap.FileRange { return seg.FileRangeOf(seg.Range()) } // FileRangeOf returns the FileRange mapped by mr. // // Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0. 
-func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange { +func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange { frstart := seg.Value() + (mr.Start - seg.Start()) - return platform.FileRange{frstart, frstart + mr.Length()} + return memmap.FileRange{frstart, frstart + mr.Length()} } // Fill attempts to ensure that all memmap.Mappable offsets in required are -// mapped to a platform.File offset, by allocating from mf with the given +// mapped to a memmap.File offset, by allocating from mf with the given // memory usage kind and invoking readAt to store data into memory. (If readAt // returns a successful partial read, Fill will call it repeatedly until all // bytes have been read.) EOF is handled consistently with the requirements of @@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map } // Drop removes segments for memmap.Mappable offsets in mr, freeing the -// corresponding platform.FileRanges. +// corresponding memmap.FileRanges. // // Preconditions: mr must be page-aligned. func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { @@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { } // DropAll removes all segments in mr, freeing the corresponding -// platform.FileRanges. +// memmap.FileRanges. func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) { for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { mf.DecRef(seg.FileRange()) diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index dd6f5aba6..a808894df 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -17,7 +17,7 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" ) @@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) { } // Merge implements segment.Functions.Merge. -func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) { +func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) { if val1 != val2 { return 0, false } @@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. } // Split implements segment.Functions.Split. -func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { +func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } // IncRefAndAccount adds a reference on the range fr. All newly inserted segments // are accounted as host page cache memory mappings. -func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) { seg, gap := refs.Find(fr.Start) for { switch { @@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { // DecRefAndAccount removes a reference on the range fr and untracks segments // that are removed from memory accounting. 
-func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) { seg := refs.FindSegment(fr.Start) for seg.Ok() && seg.Start() < fr.End { diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e82afd112..ef0113b52 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { // offsets in fr or until the next call to UnmapAll. // // Preconditions: The caller must hold a reference on all offsets in fr. -func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) { +func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) { chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift) f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) } // Preconditions: f.mapsMu must be locked. -func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error { +func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error { prot := syscall.PROT_READ if write { prot |= syscall.PROT_WRITE diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 78fec553e..c15d8a946 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -21,18 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// HostMappable implements memmap.Mappable and platform.File over a +// HostMappable implements memmap.Mappable and memmap.File over a // CachedFileObject. // // Lock order (compare the lock order model in mm/mm.go): // truncateMu ("fs locks") // mu ("memmap.Mappable locks not taken by Translate") -// ("platform.File locks") +// ("memmap.File locks") // backingFile ("CachedFileObject locks") // // +stateify savable @@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error { return nil } -// MapInternal implements platform.File.MapInternal. -func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (h *HostMappable) FD() int { return h.backingFile.FD() } -// IncRef implements platform.File.IncRef. -func (h *HostMappable) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (h *HostMappable) IncRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.IncRefOn(mr) } -// DecRef implements platform.File.DecRef. 
-func (h *HostMappable) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (h *HostMappable) DecRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.DecRefOn(mr) } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 800c8b4e1..fe8b0b6ac 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -26,7 +26,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. c.mapsMu.Lock() @@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable } } -// IncRef implements platform.File.IncRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// IncRef implements memmap.File.IncRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { +func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg, gap := c.refs.Find(fr.Start) @@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { } } -// DecRef implements platform.File.DecRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// DecRef implements memmap.File.DecRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { +func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg := c.refs.FindSegment(fr.Start) @@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { c.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. This is used when we +// MapInternal implements memmap.File.MapInternal. This is used when we // directly map an underlying host fd and CachingInodeOperations is used as the -// platform.File during translation. -func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// memmap.File during translation. +func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// FD implements memmap.File.FD. 
This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. func (c *CachingInodeOperations) FD() int { return c.backingFile.FD() diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index aabce6cc9..d41d23a43 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", + "//pkg/iovec", "//pkg/log", "//pkg/refs", "//pkg/safemem", diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index 5c18dbd5e..905afb50d 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -17,15 +17,12 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -76,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index cb91355ab..82a02fcb2 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -308,9 +308,9 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e task := kernel.TaskFromContext(ctx) if task == nil { // No task? Linux does not have an analog for this case, but - // tty_check_change is more of a blacklist of cases than a - // whitelist, and is surprisingly permissive. Allowing the - // change seems most appropriate. + // tty_check_change only blocks specific cases and is + // surprisingly permissive. Allowing the change seems + // appropriate. return nil } diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index cf440dce8..93512c9b6 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -18,12 +18,12 @@ go_library( "//pkg/context", "//pkg/safemem", "//pkg/sentry/arch", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/unimpl", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 9b0e0cca2..e6fda2b4f 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -117,7 +116,7 @@ type rootInode struct { kernfs.InodeNotSymlink kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks // Keep a reference to this inode's dentry. 
dentry kernfs.Dentry diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 1d22adbe3..1081fff52 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -18,11 +18,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -35,7 +35,7 @@ type masterInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -67,8 +67,8 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf } // Stat implements kernfs.Inode.Stat. -func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := mi.InodeAttrs.Stat(vfsfs, opts) +func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -186,7 +186,17 @@ func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatO // Stat implements vfs.FileDescriptionImpl.Stat. func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem() - return mfd.inode.Stat(fs, opts) + return mfd.inode.Stat(ctx, fs, opts) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence) } // maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid. diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index 7fe475080..a91cae3ef 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -18,10 +18,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -34,7 +34,7 @@ type slaveInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks // Keep a reference to this inode's dentry. dentry kernfs.Dentry @@ -73,8 +73,8 @@ func (si *slaveInode) Valid(context.Context) bool { } // Stat implements kernfs.Inode.Stat. 
-func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := si.InodeAttrs.Stat(vfsfs, opts) +func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -132,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen return sfd.inode.t.ld.outputQueueWrite(ctx, src) } -// Ioctl implements vfs.FileDescripionImpl.Ioctl. +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ @@ -183,5 +183,15 @@ func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOp // Stat implements vfs.FileDescriptionImpl.Stat. func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.Stat(fs, opts) + return sfd.inode.Stat(ctx, fs, opts) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence) } diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index 142ee53b0..d0e06cdc0 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -136,6 +136,8 @@ func (a *Accessor) pathOperationAt(pathname string) *vfs.PathOperation { // CreateDeviceFile creates a device special file at the given pathname in the // devtmpfs instance accessed by the Accessor. func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind vfs.DeviceKind, major, minor uint32, perms uint16) error { + actx := a.wrapContext(ctx) + mode := (linux.FileMode)(perms) switch kind { case vfs.BlockDevice: @@ -145,12 +147,24 @@ func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind v default: panic(fmt.Sprintf("invalid vfs.DeviceKind: %v", kind)) } + + // Create any parent directories. See + // devtmpfs.c:handle_create()=>path_create(). + for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() { + pop := a.pathOperationAt(it.String()) + if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{ + Mode: 0755, + }); err != nil { + return fmt.Errorf("failed to create directory %q: %v", it.String(), err) + } + } + // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't // create, which it recognizes by storing a pointer to the kdevtmpfs struct // thread in struct inode::i_private. Accessor doesn't yet support deletion // of files at all, and probably won't as long as we don't need to support // kernel modules, so this is moot for now. 
- return a.vfsObj.MknodAt(a.wrapContext(ctx), a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{ + return a.vfsObj.MknodAt(actx, a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{ Mode: mode, DevMajor: major, DevMinor: minor, diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 973fa0def..abc610ef3 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -54,13 +54,13 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/syscalls/linux", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", @@ -96,7 +96,7 @@ go_test( "//pkg/syserror", "//pkg/test/testutil", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 6bd1a9fc6..55902322a 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -63,12 +63,17 @@ func (d *dentry) DecRef() { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} +// TODO(b/134676337): Implement inotify. +func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. // -// TODO(gvisor.dev/issue/1479): Implement inotify. +// TODO(b/134676337): Implement inotify. func (d *dentry) Watches() *vfs.Watches { return nil } + +// OnZeroWatches implements vfs.Dentry.OnZeroWatches. +// +// TODO(b/134676337): Implement inotify. +func (d *dentry) OnZeroWatches() {} diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 43be6928a..357512c7e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -305,3 +306,13 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in fd.off = offset return offset, nil } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. 
+func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 5caaf14ed..30636cf66 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -55,7 +54,7 @@ type inode struct { // diskInode gives us access to the inode struct on disk. Immutable. diskInode disklayout.Inode - locks lock.FileLocks + locks vfs.FileLocks // This is immutable. The first field of the implementations must have inode // as the first field to ensure temporality. diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index 152036b2e..66d14bb95 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -149,3 +150,13 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index acb28d85b..62efd4095 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -66,6 +66,7 @@ func (in *inode) isSymlink() bool { // O_PATH. For this reason most of the functions return EBADF. type symlinkFD struct { fileDescription + vfs.NoLockFD } // Compiles only if symlinkFD implements vfs.FileDescriptionImpl. 
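The devpts and ext hunks above all wire POSIX byte-range locks the same way: each file description implements vfs.FileDescriptionImpl.LockPOSIX and UnlockPOSIX by delegating to a vfs.FileLocks instance shared with its inode. A minimal sketch of that recurring pattern for a hypothetical exampleFD type (the type name, its locks field, and its Locks method are illustrative; the fslock and vfs calls mirror the hunks above):

import (
	"gvisor.dev/gvisor/pkg/context"
	fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// exampleFD sketches the LockPOSIX/UnlockPOSIX delegation added throughout
// this change. It is not part of the commit.
type exampleFD struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl

	// locks typically points at a vfs.FileLocks embedded in the owning
	// inode, as in the devpts and ext inodes above.
	locks *vfs.FileLocks
}

// Locks returns the FileLocks backing this FD.
func (fd *exampleFD) Locks() *vfs.FileLocks { return fd.locks }

// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *exampleFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
	return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
}

// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
func (fd *exampleFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
	return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
}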
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
new file mode 100644
index 000000000..67649e811
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -0,0 +1,63 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "request_list",
+ out = "request_list.go",
+ package = "fuse",
+ prefix = "request",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Request",
+ "Linker": "*Request",
+ },
+)
+
+go_library(
+ name = "fuse",
+ srcs = [
+ "connection.go",
+ "dev.go",
+ "fusefs.go",
+ "register.go",
+ "request_list.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "dev_test",
+ size = "small",
+ srcs = ["dev_test.go"],
+ library = ":fuse",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ ],
+) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
new file mode 100644
index 000000000..f330da0bd
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -0,0 +1,255 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// MaxActiveRequestsDefault is the default setting controlling the upper bound
+// on the number of active requests at any given time.
+const MaxActiveRequestsDefault = 10000
+
+var (
+ // Ordinary requests have even IDs, while interrupt IDs are odd.
+ InitReqBit uint64 = 1
+ ReqIDStep uint64 = 2
+)
+
+// Request represents a FUSE operation request that hasn't been sent to the
+// server yet.
+//
+// +stateify savable
+type Request struct {
+ requestEntry
+
+ id linux.FUSEOpID
+ hdr *linux.FUSEHeaderIn
+ data []byte
+}
+
+// Response represents an actual response from the server, including the
+// response payload.
+//
+// +stateify savable
+type Response struct {
+ opcode linux.FUSEOpcode
+ hdr linux.FUSEHeaderOut
+ data []byte
+}
+
+// Connection is the struct by which the sentry communicates with the FUSE server daemon.
+type Connection struct {
+ fd *DeviceFD
+
+ // MaxWrite is the daemon's maximum size of a write buffer.
+ // This is negotiated during FUSE_INIT.
+ MaxWrite uint32
+}
+
+// NewFUSEConnection creates a FUSE connection to fd.
+func NewFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*Connection, error) {
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ // Create the writeBuf for the header to be stored in.
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ fuseFD.writeBuf = make([]byte, hdrLen)
+ fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse)
+ fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests)
+ fuseFD.writeCursor = 0
+
+ return &Connection{
+ fd: fuseFD,
+ }, nil
+}
+
+// NewRequest creates a new request that can be sent to the FUSE server.
+func (conn *Connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+ conn.fd.nextOpID += linux.FUSEOpID(ReqIDStep)
+
+ hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
+ hdr := linux.FUSEHeaderIn{
+ Len: uint32(hdrLen + payload.SizeBytes()),
+ Opcode: opcode,
+ Unique: conn.fd.nextOpID,
+ NodeID: ino,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ PID: pid,
+ }
+
+ buf := make([]byte, hdr.Len)
+ hdr.MarshalUnsafe(buf[:hdrLen])
+ payload.MarshalUnsafe(buf[hdrLen:])
+
+ return &Request{
+ id: hdr.Unique,
+ hdr: &hdr,
+ data: buf,
+ }, nil
+}
+
+// Call makes a request to the server and blocks the invoking task until the
+// server responds.
+// NOTE: If no task is provided then the Call will simply enqueue the request
+// and return a nil response. No blocking will happen in this case. Instead,
+// this is used to signify that the processing of this request will happen by
+// the kernel.Task that writes the response. See FUSE_INIT for such an
+// invocation.
+func (conn *Connection) Call(t *kernel.Task, r *Request) (*Response, error) {
+ fut, err := conn.callFuture(t, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return fut.resolve(t)
+}
+
+// Error returns the error of the FUSE call.
+func (r *Response) Error() error {
+ errno := r.hdr.Error
+ if errno >= 0 {
+ return nil
+ }
+
+ sysErrNo := syscall.Errno(-errno)
+ return error(sysErrNo)
+}
+
+// UnmarshalPayload unmarshals the response data into m.
+func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
+ hdrLen := r.hdr.SizeBytes()
+ haveDataLen := r.hdr.Len - uint32(hdrLen)
+ wantDataLen := uint32(m.SizeBytes())
+
+ if haveDataLen < wantDataLen {
+ return fmt.Errorf("payload too small. Minimum data length required: %d, but got data length %d", wantDataLen, haveDataLen)
+ }
+
+ m.UnmarshalUnsafe(r.data[hdrLen:])
+ return nil
+}
+
+// callFuture makes a request to the server and returns a future response.
+// Call resolve() when the response needs to be fulfilled.
+func (conn *Connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+
+ // Is the queue full?
+ //
+ // We must busy-wait here until the request can be queued. We don't
+ // block on the fd.fullQueueCh with a lock held, so after being signalled,
+ // before we acquire the lock, it is possible that a barging task enters
+ // and queues a request.
As a result, upon acquiring the lock we must
+ // check again whether room is available.
+ //
+ // This can potentially starve a request forever, but that can only happen
+ // if there are persistently too many ongoing requests. The
+ // supported maxActiveRequests setting should be high enough to avoid this.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ if t == nil {
+ // Since there is no task that is waiting, we must error out.
+ return nil, errors.New("FUSE request queue full")
+ }
+
+ log.Infof("Blocking request %v from being queued. Too many active requests: %v",
+ r.id, conn.fd.numActiveRequests)
+ conn.fd.mu.Unlock()
+ err := t.Block(conn.fd.fullQueueCh)
+ conn.fd.mu.Lock()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.callFutureLocked(t, r)
+}
+
+// callFutureLocked makes a request to the server and returns a future response.
+func (conn *Connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.queue.PushBack(r)
+ conn.fd.numActiveRequests += 1
+ fut := newFutureResponse(r.hdr.Opcode)
+ conn.fd.completions[r.id] = fut
+
+ // Signal the readers that there is something to read.
+ conn.fd.waitQueue.Notify(waiter.EventIn)
+
+ return fut, nil
+}
+
+// futureResponse represents an in-flight request that may or may not have
+// completed yet. Convert it to a resolved Response by calling resolve, but note
+// that this may block.
+//
+// +stateify savable
+type futureResponse struct {
+ opcode linux.FUSEOpcode
+ ch chan struct{}
+ hdr *linux.FUSEHeaderOut
+ data []byte
+}
+
+// newFutureResponse creates a future response to a FUSE request.
+func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse {
+ return &futureResponse{
+ opcode: opcode,
+ ch: make(chan struct{}),
+ }
+}
+
+// resolve blocks the task until the server responds to its corresponding request,
+// then returns a resolved response.
+func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
+ // If there is no Task associated with this request, then we don't try to resolve
+ // the response. Instead, the task writing the response (proxy to the server) will
+ // process the response on our behalf.
+ if t == nil {
+ log.Infof("fuse.Response.resolve: Not waiting on a response from server.")
+ return nil, nil
+ }
+
+ if err := t.Block(f.ch); err != nil {
+ return nil, err
+ }
+
+ return f.getResponse(), nil
+}
+
+// getResponse creates a Response from the data the futureResponse has.
+func (f *futureResponse) getResponse() *Response {
+ return &Response{
+ opcode: f.opcode,
+ hdr: *f.hdr,
+ data: f.data,
+ }
+} diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
new file mode 100644
index 000000000..f3443ac71
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -0,0 +1,394 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const fuseDevMinor = 229
+
+// fuseDevice implements vfs.Device for /dev/fuse.
+type fuseDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if !kernel.FUSEEnabled {
+ return nil, syserror.ENOENT
+ }
+
+ var fd DeviceFD
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse.
+type DeviceFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
+ mounted bool
+
+ // nextOpID is used to create new requests.
+ nextOpID linux.FUSEOpID
+
+ // queue is the list of requests that need to be processed by the FUSE server.
+ queue requestList
+
+ // numActiveRequests is the number of requests made by the Sentry that have
+ // yet to be responded to.
+ numActiveRequests uint64
+
+ // completions is used to map a request to its response. A Writer will use this
+ // to notify the caller of a completed response.
+ completions map[linux.FUSEOpID]*futureResponse
+
+ // writeCursor is the number of bytes of the response currently being
+ // written that have been copied in so far.
+ writeCursor uint32
+
+ // writeBuf is the memory buffer used to copy in the FUSE out header from
+ // userspace.
+ writeBuf []byte
+
+ // writeCursorFR is the futureResponse currently being copied from the server.
+ writeCursorFR *futureResponse
+
+ // mu protects all the queues, maps, buffers, cursors, and nextOpID.
+ mu sync.Mutex
+
+ // waitQueue is used to notify interested parties when the device becomes
+ // readable or writable.
+ waitQueue waiter.Queue
+
+ // fullQueueCh is a channel used to synchronize the readers with the writers.
+ // Writers (inbound requests to the filesystem) block if there are too many
+ // unprocessed in-flight requests.
+ fullQueueCh chan struct{}
+
+ // fs is the FUSE filesystem that this FD is being used for.
+ fs *filesystem
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DeviceFD) Release() {}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ // We require that any Read done on this filesystem have a sane minimum
+ // read buffer. It must have the capacity for the fixed parts of any request
+ // header (Linux uses the request header and the FUSEWriteIn header for this
+ // calculation) + the negotiated MaxWrite room for the data.
+ minBuffSize := linux.FUSE_MIN_READ_BUFFER
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes())
+ negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.MaxWrite
+ if minBuffSize < negotiatedMinBuffSize {
+ minBuffSize = negotiatedMinBuffSize
+ }
+
+ // If the read buffer is too small, error out.
+ if dst.NumBytes() < int64(minBuffSize) {
+ return 0, syserror.EINVAL
+ }
+
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.readLocked(ctx, dst, opts)
+}
+
+// readLocked reads from the fuse device with DeviceFD.mu held.
+func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if fd.queue.Empty() {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var readCursor uint32
+ var bytesRead int64
+ for {
+ req := fd.queue.Front()
+ if dst.NumBytes() < int64(req.hdr.Len) {
+ // The request is too large. Cannot process it. All requests must be smaller than the
+ // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT
+ // handshake.
+ errno := -int32(syscall.EIO)
+ if req.hdr.Opcode == linux.FUSE_SETXATTR {
+ errno = -int32(syscall.E2BIG)
+ }
+
+ // Return the error to the calling task.
+ if err := fd.sendError(ctx, errno, req); err != nil {
+ return 0, err
+ }
+
+ // We're done with this request.
+ fd.queue.Remove(req)
+
+ // Restart the read as this request was invalid.
+ log.Warningf("fuse.DeviceFD.Read: request was too large. Restarting read.")
+ return fd.readLocked(ctx, dst, opts)
+ }
+
+ n, err := dst.CopyOut(ctx, req.data[readCursor:])
+ if err != nil {
+ return 0, err
+ }
+ readCursor += uint32(n)
+ bytesRead += int64(n)
+
+ if readCursor >= req.hdr.Len {
+ // Fully done with this req, remove it from the queue.
+ fd.queue.Remove(req)
+ break
+ }
+ }
+
+ return bytesRead, nil
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.writeLocked(ctx, src, opts)
+}
+
+// writeLocked writes to the fuse device with DeviceFD.mu held.
+func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ var cn, n int64
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+
+ for src.NumBytes() > 0 {
+ if fd.writeCursorFR != nil {
+ // Already have common header, and we're now copying the payload.
+ wantBytes := fd.writeCursorFR.hdr.Len
+
+ // Note that the FR data doesn't have the header. Copy it over if necessary.
+ if fd.writeCursorFR.data == nil {
+ fd.writeCursorFR.data = make([]byte, wantBytes)
+ }
+
+ bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == wantBytes {
+ // Done reading this full response.
Clean up and unblock the
+ // initiator.
+ break
+ }
+
+ // Check if we have more data in src.
+ continue
+ }
+
+ // Assert that the header isn't read into the writeBuf yet.
+ if fd.writeCursor >= hdrLen {
+ return 0, syserror.EINVAL
+ }
+
+ // We don't have the full common response header yet.
+ wantBytes := hdrLen - fd.writeCursor
+ bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == hdrLen {
+ // Have full header in the writeBuf. Use it to fetch the actual futureResponse
+ // from the device's completions map.
+ var hdr linux.FUSEHeaderOut
+ hdr.UnmarshalBytes(fd.writeBuf)
+
+ // We have the header now and so the writeBuf has served its purpose.
+ // We could reset it manually here but instead of doing that, at the
+ // end of the write, the writeCursor will be set to 0 thereby allowing
+ // the next request to overwrite what's in the buffer.
+
+ fut, ok := fd.completions[hdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return 0, syserror.EINVAL
+ }
+
+ delete(fd.completions, hdr.Unique)
+
+ // Copy over the header into the future response. The rest of the payload
+ // will be copied over to the FR's data in the next iteration.
+ fut.hdr = &hdr
+ fd.writeCursorFR = fut
+
+ // Next iteration will now try to read the complete request, if src has
+ // any data remaining. Otherwise we're done.
+ }
+ }
+
+ if fd.writeCursorFR != nil {
+ if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil {
+ return 0, err
+ }
+
+ // Ready the device for the next request.
+ fd.writeCursorFR = nil
+ fd.writeCursor = 0
+ }
+
+ return n, nil
+}
+
+// Readiness implements vfs.FileDescriptionImpl.Readiness.
+func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ ready |= waiter.EventOut // FD is always writable
+ if !fd.queue.Empty() {
+ // Have reqs available, FD is readable.
+ ready |= waiter.EventIn
+ }
+
+ return ready & mask
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.waitQueue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *DeviceFD) EventUnregister(e *waiter.Entry) {
+ fd.waitQueue.EventUnregister(e)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// sendResponse sends a response to the waiting task (if any).
+func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error {
+ // See if the running task needs to perform some action before returning.
+ // Since we just finished writing the future, we can be sure that
+ // getResponse generates a populated response.
+ if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil {
+ return err
+ }
+
+ // Signal that the queue is no longer full.
+ select {
+ case fd.fullQueueCh <- struct{}{}:
+ default:
+ }
+ fd.numActiveRequests -= 1
+
+ // Signal the task waiting on a response.
+ close(fut.ch)
+ return nil
+}
+
+// sendError sends an error response to the waiting task (if any).
+func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error {
+ // Return the error to the calling task.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ respHdr := linux.FUSEHeaderOut{
+ Len: outHdrLen,
+ Error: errno,
+ Unique: req.hdr.Unique,
+ }
+
+ fut, ok := fd.completions[respHdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return syserror.EINVAL
+ }
+ delete(fd.completions, respHdr.Unique)
+
+ fut.hdr = &respHdr
+ if err := fd.sendResponse(ctx, fut); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// noReceiverAction has the calling kernel.Task do some action if it's known that no
+// receiver is going to be waiting on the future channel. This is to be used by:
+// FUSE_INIT.
+func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error {
+ if r.opcode == linux.FUSE_INIT {
+ // TODO: process init response here.
+ // Maybe get the creds from the context?
+ // creds := auth.CredentialsFromContext(ctx)
+ }
+
+ return nil
+} diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
new file mode 100644
index 000000000..fcd77832a
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -0,0 +1,429 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// echoTestOpcode is the Opcode used during testing. The server used in tests
+// will simply echo the payload back with the appropriate headers.
+const echoTestOpcode linux.FUSEOpcode = 1000
+
+type testPayload struct {
+ data uint32
+}
+
+// TestFUSECommunication tests that the communication layer between the Sentry and the
+// FUSE server daemon works as expected.
+func TestFUSECommunication(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ creds := auth.CredentialsFromContext(s.Ctx)
+
+ // Create test cases with different numbers of concurrent clients and servers.
+ testCases := []struct {
+ Name string
+ NumClients int
+ NumServers int
+ MaxActiveRequests uint64
+ }{
+ {
+ Name: "SingleClientSingleServer",
+ NumClients: 1,
+ NumServers: 1,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "SingleClientMultipleServers",
+ NumClients: 1,
+ NumServers: 10,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsSingleServer",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsMultipleServers",
+ NumClients: 10,
+ NumServers: 10,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "RequestCapacityFull",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: 1,
+ },
+ {
+ Name: "RequestCapacityContinuouslyFull",
+ NumClients: 100,
+ NumServers: 2,
+ MaxActiveRequests: 2,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ clientsDone := make([]chan struct{}, testCase.NumClients)
+ serversDone := make([]chan struct{}, testCase.NumServers)
+ serversKill := make([]chan struct{}, testCase.NumServers)
+
+ // FUSE clients.
+ for i := 0; i < testCase.NumClients; i++ {
+ clientsDone[i] = make(chan struct{})
+ go func(i int) {
+ fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
+ }(i)
+ }
+
+ // FUSE servers.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversDone[j] = make(chan struct{})
+ serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
+ go func(j int) {
+ fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
+ }(j)
+ }
+
+ // Tear down.
+ //
+ // Make sure all the clients are done.
+ for i := 0; i < testCase.NumClients; i++ {
+ <-clientsDone[i]
+ }
+
+ // Kill any server that is potentially waiting.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversKill[j] <- struct{}{}
+ }
+
+ // Make sure all the servers are done.
+ for j := 0; j < testCase.NumServers; j++ {
+ <-serversDone[j]
+ }
+ })
+ }
+}
+
+// CallTest makes a request to the server and blocks the invoking
+// goroutine until the server responds. Doesn't block
+// a kernel.Task. Analogous to Connection.Call but used for testing.
+func CallTest(conn *Connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
+ conn.fd.mu.Lock()
+
+ // Wait until we're certain that a new request can be processed.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ conn.fd.mu.Unlock()
+ select {
+ case <-conn.fd.fullQueueCh:
+ }
+ conn.fd.mu.Lock()
+ }
+
+ fut, err := conn.callFutureLocked(t, r) // No task given.
+ conn.fd.mu.Unlock()
+
+ if err != nil {
+ return nil, err
+ }
+
+ // Resolve the response.
+ //
+ // Block without a task.
+ select {
+ case <-fut.ch:
+ }
+
+ // A response is ready. Resolve and return it.
+ return fut.getResponse(), nil
+}
+
+// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
+// device. However, it does so without blocking the calling task; instead, it
+// just waits on a channel. The behaviour is essentially the same as
+// DeviceFD.Read except it guarantees that the task is not blocked.
+func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
+ var err error
+ var n, total int64
+
+ dev := fd.Impl().(*DeviceFD)
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ dev.EventRegister(&w, waiter.EventIn)
+ for {
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ // Emulate blocking when no requests are available.
+ select {
+ case <-ch:
+ case <-killServer:
+ // Server killed by the main program.
+ return 0, true, nil
+ }
+ }
+
+ dev.EventUnregister(&w)
+ return total, false, err
+}
+
+// fuseClientRun emulates all the actions of a normal FUSE request. It creates
+// a header and a payload, calls the server, waits for the response, and
+// processes it.
+func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *Connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
+ defer func() { clientDone <- struct{}{} }()
+
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testObj := &testPayload{
+ data: rand.Uint32(),
+ }
+
+ req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+
+ // Queue up a request.
+ // Analogous to Call except it doesn't block on the task.
+ resp, err := CallTest(conn, clientTask, req, pid)
+ if err != nil {
+ t.Fatalf("CallTaskNonBlock failed: %v", err)
+ }
+
+ if err = resp.Error(); err != nil {
+ t.Fatalf("Server responded with an error: %v", err)
+ }
+
+ var respTestPayload testPayload
+ if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
+ t.Fatalf("Unmarshalling payload error: %v", err)
+ }
+
+ if resp.hdr.Unique != req.hdr.Unique {
+ t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
+ req.hdr.Unique, resp.hdr.Unique)
+ }
+
+ if respTestPayload.data != testObj.data {
+ t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
+ }
+
+}
+
+// fuseServerRun creates a task and emulates all the actions of a simple FUSE server
+// that simply reads a request and echoes the same struct back as a response using the
+// appropriate headers.
+func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
+ defer func() { serverDone <- struct{}{} }()
+
+ // Create the tasks that the server will be using.
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ var readPayload testPayload
+
+ serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Read the request.
+ for {
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ payloadLen := uint32(readPayload.SizeBytes())
+
+ // The read buffer must meet certain size criteria.
+ buffSize := inHdrLen + payloadLen
+ if buffSize < linux.FUSE_MIN_READ_BUFFER {
+ buffSize = linux.FUSE_MIN_READ_BUFFER
+ }
+ inBuf := make([]byte, buffSize)
+ inIOseq := usermem.BytesIOSequence(inBuf)
+
+ n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ // Server should shut down. No new requests are going to be made.
+ if serverKilled {
+ break
+ }
+
+ if n <= 0 {
+ t.Fatalf("Read returned no bytes")
+ }
+
+ var readFUSEHeaderIn linux.FUSEHeaderIn
+ readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
+ readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
+
+ if readFUSEHeaderIn.Opcode != echoTestOpcode {
+ t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
+ }
+
+ // Write the response.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ outBuf := make([]byte, outHdrLen+payloadLen)
+ outHeader := linux.FUSEHeaderOut{
+ Len: outHdrLen + payloadLen,
+ Error: 0,
+ Unique: readFUSEHeaderIn.Unique,
+ }
+
+ // Echo the payload back.
+ outHeader.MarshalUnsafe(outBuf[:outHdrLen])
+ readPayload.MarshalUnsafe(outBuf[outHdrLen:])
+ outIOseq := usermem.BytesIOSequence(outBuf)
+
+ n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+ }
+}
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+// newTestConnection creates a fuse connection that the sentry can communicate with
+// and the FD for the server to communicate with.
+func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*Connection, *vfs.FileDescription, error) {
+ vfsObj := &vfs.VirtualFilesystem{}
+ fuseDev := &DeviceFD{}
+
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, err
+ }
+
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef()
+ if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, nil, err
+ }
+
+ fsopts := filesystemOptions{
+ maxActiveRequests: maxActiveRequests,
+ }
+ fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return fs.conn, &fuseDev.vfsfd, nil
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *testPayload) SizeBytes() int {
+ return 4
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *testPayload) MarshalBytes(dst []byte) {
+ usermem.ByteOrder.PutUint32(dst[:4], t.data)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *testPayload) UnmarshalBytes(src []byte) {
+ *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (t *testPayload) Packed() bool {
+ return true
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (t *testPayload) MarshalUnsafe(dst []byte) {
+	t.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (t *testPayload) UnmarshalUnsafe(src []byte) {
+	t.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+	panic("not implemented")
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+	panic("not implemented")
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+	panic("not implemented")
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
+	panic("not implemented")
+}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
new file mode 100644
index 000000000..911b6f7cb
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -0,0 +1,224 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fuse implements fusefs.
+package fuse
+
+import (
+	"strconv"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "fuse"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+type filesystemOptions struct {
+	// userID specifies the numeric uid of the mount owner.
+	// This option should not be specified by the filesystem owner.
+	// It is set by libfuse (or, if libfuse is not used, must be set
+	// by the filesystem itself). For more information, see the manual
+	// page for fuse(8).
+	userID uint32
+
+	// groupID specifies the numeric gid of the mount owner.
+	// This option should not be specified by the filesystem owner.
+	// It is set by libfuse (or, if libfuse is not used, must be set
+	// by the filesystem itself). For more information, see the manual
+	// page for fuse(8).
+	groupID uint32
+
+	// rootMode specifies the file mode of the filesystem's root.
+	rootMode linux.FileMode
+
+	// maxActiveRequests specifies the maximum number of active requests that can
+	// exist at any time. Any further requests will block when trying to
+	// Call the server.
+	maxActiveRequests uint64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+	kernfs.Filesystem
+	devMinor uint32
+
+	// conn is used for communication between the FUSE server
+	// daemon and the sentry fusefs.
+	conn *Connection
+
+	// opts is the options the fusefs is initialized with.
+	opts *filesystemOptions
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+	return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+	devMinor, err := vfsObj.GetAnonBlockDevMinor()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	var fsopts filesystemOptions
+	mopts := vfs.GenericParseMountOptions(opts.Data)
+	deviceDescriptorStr, ok := mopts["fd"]
+	if !ok {
+		log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+		return nil, nil, syserror.EINVAL
+	}
+	delete(mopts, "fd")
+
+	deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	kernelTask := kernel.TaskFromContext(ctx)
+	if kernelTask == nil {
+		log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
+		return nil, nil, syserror.EINVAL
+	}
+	fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+
+	// Parse and set all the other supported FUSE mount options.
+	// TODO(gvisor.dev/issue/3229): Expand the supported mount options.
+	if userIDStr, ok := mopts["user_id"]; ok {
+		delete(mopts, "user_id")
+		userID, err := strconv.ParseUint(userIDStr, 10, 32)
+		if err != nil {
+			log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+			return nil, nil, syserror.EINVAL
+		}
+		fsopts.userID = uint32(userID)
+	}
+
+	if groupIDStr, ok := mopts["group_id"]; ok {
+		delete(mopts, "group_id")
+		groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+		if err != nil {
+			log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+			return nil, nil, syserror.EINVAL
+		}
+		fsopts.groupID = uint32(groupID)
+	}
+
+	rootMode := linux.FileMode(0777)
+	modeStr, ok := mopts["rootmode"]
+	if ok {
+		delete(mopts, "rootmode")
+		mode, err := strconv.ParseUint(modeStr, 8, 32)
+		if err != nil {
+			log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
+			return nil, nil, syserror.EINVAL
+		}
+		rootMode = linux.FileMode(mode)
+	}
+	fsopts.rootMode = rootMode
+
+	// Set the maxActiveRequests option.
+	fsopts.maxActiveRequests = MaxActiveRequestsDefault
+
+	// Check for unparsed options.
+	if len(mopts) != 0 {
+		log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+		return nil, nil, syserror.EINVAL
+	}
+
+	// Create a new FUSE filesystem.
+	fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
+	if err != nil {
+		log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err)
+		return nil, nil, err
+	}
+
+	fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+	// TODO: dispatch a FUSE_INIT request to the FUSE daemon server before
+	// returning. Mount will not block on this dispatched request.
+
+	// root is the fusefs root directory.
+	root := fs.newInode(creds, fsopts.rootMode)
+
+	return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
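GetFilesystem above consumes a libfuse-style mount data string. A self-contained sketch of parsing such a string outside the sentry; the struct and helper here are hypothetical stand-ins, but the option names, numeric bases, and the 0777 default match the handling above:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// fuseMountOpts holds the options GetFilesystem understands.
type fuseMountOpts struct {
	fd              int32
	rootMode        uint32
	userID, groupID uint32
}

// parseFUSEMountData parses mount data like "fd=7,rootmode=40000,user_id=0,group_id=0".
func parseFUSEMountData(data string) (fuseMountOpts, error) {
	opts := fuseMountOpts{rootMode: 0o777} // default when rootmode is absent
	for _, kv := range strings.Split(data, ",") {
		k, v, ok := strings.Cut(kv, "=")
		if !ok {
			return opts, fmt.Errorf("malformed option %q", kv)
		}
		switch k {
		case "fd":
			n, err := strconv.ParseInt(v, 10, 32)
			if err != nil {
				return opts, err
			}
			opts.fd = int32(n)
		case "rootmode":
			n, err := strconv.ParseUint(v, 8, 32) // octal, as above
			if err != nil {
				return opts, err
			}
			opts.rootMode = uint32(n)
		case "user_id", "group_id":
			n, err := strconv.ParseUint(v, 10, 32)
			if err != nil {
				return opts, err
			}
			if k == "user_id" {
				opts.userID = uint32(n)
			} else {
				opts.groupID = uint32(n)
			}
		default:
			return opts, fmt.Errorf("unknown option %q", k)
		}
	}
	return opts, nil
}

func main() {
	fmt.Println(parseFUSEMountData("fd=7,rootmode=40000,user_id=0,group_id=0"))
}

+// NewFUSEFilesystem creates a new FUSE filesystem.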
+func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
+	fs := &filesystem{
+		devMinor: devMinor,
+		opts:     opts,
+	}
+
+	conn, err := NewFUSEConnection(ctx, device, opts.maxActiveRequests)
+	if err != nil {
+		log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
+		return nil, syserror.EINVAL
+	}
+
+	fs.conn = conn
+	fuseFD := device.Impl().(*DeviceFD)
+	fuseFD.fs = fs
+
+	return fs, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+	fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+	fs.Filesystem.Release()
+}
+
+// Inode implements kernfs.Inode.
+type Inode struct {
+	kernfs.InodeAttrs
+	kernfs.InodeNoDynamicLookup
+	kernfs.InodeNotSymlink
+	kernfs.InodeDirectoryNoNewChildren
+	kernfs.OrderedChildren
+
+	locks vfs.FileLocks
+
+	dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+	i := &Inode{}
+	i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+	i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+	i.dentry.Init(i)
+
+	return &i.dentry
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *Inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+	fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+	if err != nil {
+		return nil, err
+	}
+	return fd.VFSFileDescription(), nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go
new file mode 100644
index 000000000..b5b581152
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/register.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers the FUSE device with vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+	if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
+		GroupName: "misc",
+	}); err != nil {
+		return err
+	}
+
+	return nil
+}
+
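Register above wires the device into the sentry's VFS, and CreateDevtmpfsFile (below) surfaces it as /dev/fuse. For orientation, a sketch of the other half of the handshake: how a FUSE daemon typically opens the device and hands the descriptor to mount(2) through the fd= option that GetFilesystem parses. The target path and rootmode value are illustrative, not taken from this change:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// mountFUSE opens /dev/fuse and mounts a FUSE filesystem at target, passing
// the device FD in the mount data in the form GetFilesystem expects. The
// daemon then serves requests by reading from the returned FD.
func mountFUSE(target string) (int, error) {
	fd, err := unix.Open("/dev/fuse", unix.O_RDWR, 0)
	if err != nil {
		return -1, err
	}
	data := fmt.Sprintf("fd=%d,rootmode=40000,user_id=0,group_id=0", fd)
	if err := unix.Mount("fuse", target, "fuse", 0, data); err != nil {
		unix.Close(fd)
		return -1, err
	}
	return fd, nil
}

+// CreateDevtmpfsFile creates a device special file in devtmpfs.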
+func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error { + if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil { + return err + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 5cdeeaeb5..4a800dcf9 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -69,7 +69,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index b98218753..8c7c8e1b3 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -85,6 +85,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { d2 := &dentry{ refs: 1, // held by d fs: d.fs, + ino: d.fs.nextSyntheticIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -138,6 +139,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.dirents = ds } + d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) if d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } @@ -183,13 +185,13 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { { Name: ".", Type: linux.DT_DIR, - Ino: d.ino, + Ino: uint64(d.ino), NextOff: 1, }, { Name: "..", Type: uint8(atomic.LoadUint32(&parent.mode) >> 12), - Ino: parent.ino, + Ino: uint64(parent.ino), NextOff: 2, }, } @@ -225,7 +227,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: p9d.QID.Path, + Ino: uint64(inoFromPath(p9d.QID.Path)), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or @@ -258,7 +260,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { dirents = append(dirents, vfs.Dirent{ Name: child.name, Type: uint8(atomic.LoadUint32(&child.mode) >> 12), - Ino: child.ino, + Ino: uint64(child.ino), NextOff: int64(len(dirents) + 1), }) } @@ -299,3 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in return 0, syserror.EINVAL } } + +// Sync implements vfs.FileDescriptionImpl.Sync. +func (fd *directoryFD) Sync(ctx context.Context) error { + return fd.dentry().handle.sync(ctx) +} diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 3c467e313..00e3c99cd 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -16,6 +16,7 @@ package gofer import ( "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -149,11 +150,9 @@ afterSymlink: return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { - _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { + if err := d.parent.updateFromGetattr(ctx); err != nil { return nil, err } - d.parent.updateFromP9Attrs(attrMask, &attr) } rp.Advance() return d.parent, nil @@ -208,18 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil // Preconditions: As for getChildLocked. !parent.isSynthetic(). 
func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { + if child != nil { + // Need to lock child.metadataMu because we might be updating child + // metadata. We need to hold the lock *before* getting metadata from the + // server and release it after updating local metadata. + child.metadataMu.Lock() + } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil && err != syserror.ENOENT { + if child != nil { + child.metadataMu.Unlock() + } return nil, err } if child != nil { - if !file.isNil() && qid.Path == child.ino { - // The file at this path hasn't changed. Just update cached - // metadata. + if !file.isNil() && inoFromPath(qid.Path) == child.ino { + // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) + child.updateFromP9AttrsLocked(attrMask, &attr) + child.metadataMu.Unlock() return child, nil } + child.metadataMu.Unlock() if file.isNil() && child.isSynthetic() { // We have a synthetic file, and no remote file has arisen to // replace it. @@ -371,17 +380,33 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } parent.touchCMtime() parent.dirents = nil + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } if fs.opts.interop == InteropModeShared { - // The existence of a dentry at name would be inconclusive because the - // file it represents may have been deleted from the remote filesystem, - // so we would need to make an RPC to revalidate the dentry. Just - // attempt the file creation RPC instead. If a file does exist, the RPC - // will fail with EEXIST like we would have. If the RPC succeeds, and a - // stale dentry exists, the dentry will fail revalidation next time - // it's used. - return createInRemoteDir(parent, name) + if child := parent.children[name]; child != nil && child.isSynthetic() { + return syserror.EEXIST + } + // The existence of a non-synthetic dentry at name would be inconclusive + // because the file it represents may have been deleted from the remote + // filesystem, so we would need to make an RPC to revalidate the dentry. + // Just attempt the file creation RPC instead. If a file does exist, the + // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a + // stale dentry exists, the dentry will fail revalidation next time it's + // used. + if err := createInRemoteDir(parent, name); err != nil { + return err + } + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + return nil } if child := parent.children[name]; child != nil { return syserror.EEXIST @@ -397,6 +422,11 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } parent.touchCMtime() parent.dirents = nil + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } @@ -443,21 +473,61 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b defer mntns.DecRef() parent.dirMu.Lock() defer parent.dirMu.Unlock() + child, ok := parent.children[name] if ok && child == nil { return syserror.ENOENT } - // We only need a dentry representing the file at name if it can be a mount - // point. 
If child is nil, then it can't be a mount point. If child is
-	// non-nil but stale, the actual file can't be a mount point either; we
-	// detect this case by just speculatively calling PrepareDeleteDentry and
-	// only revalidating the dentry if that fails (indicating that the existing
-	// dentry is a mount point).
+
+	sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
+	if sticky {
+		if !ok {
+			// If the sticky bit is set, we need to retrieve the child to determine
+			// whether removing it is allowed.
+			child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+			if err != nil {
+				return err
+			}
+		} else if child != nil && !child.cachedMetadataAuthoritative() {
+			// Make sure the dentry representing the file at name is up to date
+			// before examining its metadata.
+			child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
+			if err != nil {
+				return err
+			}
+		}
+		if err := parent.mayDelete(rp.Credentials(), child); err != nil {
+			return err
+		}
+	}
+
+	// If a child dentry exists, prepare to delete it. This should fail if it is
+	// a mount point. We detect mount points by speculatively calling
+	// PrepareDeleteDentry, which fails if child is a mount point. However, we
+	// may need to revalidate the file in this case to make sure that it has not
+	// been deleted or replaced on the remote fs, in which case the mount point
+	// will have disappeared. If calling PrepareDeleteDentry fails again on the
+	// up-to-date dentry, we can be sure that it is a mount point.
+	//
+	// Also note that if child is nil, then it can't be a mount point.
 	if child != nil {
+		// Hold child.dirMu so we can check child.children and
+		// child.syntheticChildren. We don't access these fields until a bit later,
+		// but locking child.dirMu after calling vfs.PrepareDeleteDentry() would
+		// create an inconsistent lock ordering between dentry.dirMu and
+		// vfs.Dentry.mu (in the VFS lock order, it would make dentry.dirMu both "a
+		// FilesystemImpl lock" and "a lock acquired by a FilesystemImpl between
+		// PrepareDeleteDentry and CommitDeleteDentry"). To avoid this, lock
+		// child.dirMu before calling PrepareDeleteDentry.
 		child.dirMu.Lock()
 		defer child.dirMu.Unlock()
 		if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
-			if parent.cachedMetadataAuthoritative() {
+			// We can skip revalidation in several cases:
+			// - We are not in InteropModeShared
+			// - The parent directory is synthetic, in which case the child must also
+			//   be synthetic
+			// - We already updated the child during the sticky bit check above
+			if parent.cachedMetadataAuthoritative() || sticky {
 				return err
 			}
 			child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
@@ -518,7 +588,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
 		if child == nil {
 			return syserror.ENOENT
 		}
-	} else {
+	} else if child == nil || !child.isSynthetic() {
 		err = parent.file.unlinkAt(ctx, name, flags)
 		if err != nil {
 			if child != nil {
@@ -527,6 +597,18 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
 			return err
 		}
 	}
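The sticky-bit branch above exists because deletion from a sticky directory depends on who owns the victim, which may require revalidated metadata. A condensed model of the classic rule (Linux's fs/namei.c:check_sticky); the predicate actually used above is vfs.CheckDeleteSticky, and the CAP_FOWNER override is deliberately not modeled here:

// stickyDeleteAllowed reports whether callerUID may delete a file owned by
// childUID from a directory with the given mode and owner. In a non-sticky
// directory, write permission on the directory (checked elsewhere) suffices.
func stickyDeleteAllowed(parentMode, parentUID, childUID, callerUID uint32) bool {
	const modeSticky = 0o1000 // linux.ModeSticky
	if parentMode&modeSticky == 0 {
		return true
	}
	return callerUID == parentUID || callerUID == childUID
}

+
+	// Generate inotify events for rmdir or unlink.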
+	if dir {
+		parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+	} else {
+		var cw *vfs.Watches
+		if child != nil {
+			cw = &child.watches
+		}
+		vfs.InotifyRemoveChild(cw, &parent.watches, name)
+	}
+
 	if child != nil {
 		vfsObj.CommitDeleteDentry(&child.vfsd)
 		child.setDeleted()
@@ -764,15 +846,17 @@ afterTrailingSymlink:
 		parent.dirMu.Unlock()
 		return fd, err
 	}
+	parent.dirMu.Unlock()
 	if err != nil {
-		parent.dirMu.Unlock()
 		return nil, err
 	}
-	// Open existing child or follow symlink.
-	parent.dirMu.Unlock()
 	if mustCreate {
 		return nil, syserror.EEXIST
 	}
+	if !child.isDir() && rp.MustBeDir() {
+		return nil, syserror.ENOTDIR
+	}
+	// Open existing child or follow symlink.
 	if child.isSymlink() && rp.ShouldFollowSymlink() {
 		target, err := child.readlink(ctx, rp.Mount())
 		if err != nil {
@@ -793,11 +877,22 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
 	if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
 		return nil, err
 	}
+
+	trunc := opts.Flags&linux.O_TRUNC != 0 && d.fileType() == linux.S_IFREG
+	if trunc {
+		// Lock metadataMu *while* we open a regular file with O_TRUNC because
+		// open(2) will change the file size on the server.
+		d.metadataMu.Lock()
+		defer d.metadataMu.Unlock()
+	}
+
+	var vfd *vfs.FileDescription
+	var err error
 	mnt := rp.Mount()
 	switch d.fileType() {
 	case linux.S_IFREG:
 		if !d.fs.opts.regularFilesUseSpecialFileFD {
-			if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil {
+			if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil {
 				return nil, err
 			}
 			fd := &regularFileFD{}
@@ -807,7 +902,7 @@
 			}); err != nil {
 				return nil, err
 			}
-			return &fd.vfsfd, nil
+			vfd = &fd.vfsfd
 		}
 	case linux.S_IFDIR:
 		// Can't open directories with O_CREAT.
@@ -847,7 +942,25 @@
 			return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks)
 		}
 	}
-	return d.openSpecialFileLocked(ctx, mnt, opts)
+
+	if vfd == nil {
+		if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil {
+			return nil, err
+		}
+	}
+
+	if trunc {
+		// If no errors occurred so far then update file size in memory. This
+		// step is required even if !d.cachedMetadataAuthoritative() because
+		// d.mappings has to be updated.
+		// d.metadataMu has already been acquired if trunc == true.
+ d.updateFileSizeLocked(0) + + if d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + } + return vfd, err } func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { @@ -1013,6 +1126,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } + d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) return childVFSFD, nil } @@ -1064,7 +1178,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return err } } - if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + creds := rp.Credentials() + if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } vfsObj := rp.VirtualFilesystem() @@ -1079,12 +1194,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if renamed == nil { return syserror.ENOENT } + if err := oldParent.mayDelete(creds, renamed); err != nil { + return err + } if renamed.isDir() { if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { return syserror.EINVAL } if oldParent != newParent { - if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { return err } } @@ -1095,7 +1213,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } if oldParent != newParent { - if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } newParent.dirMu.Lock() @@ -1193,10 +1311,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.cachedMetadataAuthoritative() { newParent.dirents = nil newParent.touchCMtime() - if renamed.isDir() { + if renamed.isDir() && (replaced == nil || !replaced.isDir()) { + // Increase the link count if we did not replace another directory. newParent.incLinks() } } + vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir()) return nil } @@ -1209,12 +1329,21 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) return err } - return d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()) + if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) + return err + } + fs.renameMuRUnlockAndCheckCaching(&ds) + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // StatAt implements vfs.FilesystemImpl.StatAt. 
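SetStatAt above derives the inotify event from the statx mask and publishes it only after renameMu is dropped. The real translation is vfs.InotifyEventFromStatMask; the mapping sketched below is an assumption based on Linux's behavior, using the abi/linux constants from this tree:

// eventFromStatMask sketches a statx-mask-to-inotify translation: attribute
// changes surface as IN_ATTRIB, size changes as IN_MODIFY, and atime-only
// changes as IN_ACCESS.
func eventFromStatMask(mask uint32) uint32 {
	var ev uint32
	if mask&(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
		ev |= linux.IN_ATTRIB
	}
	if mask&linux.STATX_SIZE != 0 {
		ev |= linux.IN_MODIFY
	}
	if mask&linux.STATX_ATIME != 0 && ev == 0 {
		ev |= linux.IN_ACCESS
	}
	return ev
}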
@@ -1338,24 +1467,38 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt
 func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
 	var ds *[]*dentry
 	fs.renameMu.RLock()
-	defer fs.renameMuRUnlockAndCheckCaching(&ds)
 	d, err := fs.resolveLocked(ctx, rp, &ds)
 	if err != nil {
+		fs.renameMuRUnlockAndCheckCaching(&ds)
+		return err
+	}
+	if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil {
+		fs.renameMuRUnlockAndCheckCaching(&ds)
 		return err
 	}
-	return d.setxattr(ctx, rp.Credentials(), &opts)
+	fs.renameMuRUnlockAndCheckCaching(&ds)
+
+	d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+	return nil
 }
 
 // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
 func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
 	var ds *[]*dentry
 	fs.renameMu.RLock()
-	defer fs.renameMuRUnlockAndCheckCaching(&ds)
 	d, err := fs.resolveLocked(ctx, rp, &ds)
 	if err != nil {
+		fs.renameMuRUnlockAndCheckCaching(&ds)
+		return err
+	}
+	if err := d.removexattr(ctx, rp.Credentials(), name); err != nil {
+		fs.renameMuRUnlockAndCheckCaching(&ds)
 		return err
 	}
-	return d.removexattr(ctx, rp.Credentials(), name)
+	fs.renameMuRUnlockAndCheckCaching(&ds)
+
+	d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+	return nil
 }
 
 // PrependPath implements vfs.FilesystemImpl.PrependPath.
@@ -1364,3 +1507,7 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
 	defer fs.renameMu.RUnlock()
 	return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
 }
+
+func (fs *filesystem) nextSyntheticIno() inodeNumber {
+	return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 0d88a328e..e20de84b5 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -53,7 +53,6 @@ import (
 	"gvisor.dev/gvisor/pkg/sentry/pgalloc"
 	"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
 	"gvisor.dev/gvisor/pkg/sentry/vfs"
-	"gvisor.dev/gvisor/pkg/sentry/vfs/lock"
 	"gvisor.dev/gvisor/pkg/syserror"
 	"gvisor.dev/gvisor/pkg/unet"
 	"gvisor.dev/gvisor/pkg/usermem"
@@ -111,6 +110,26 @@ type filesystem struct {
 	syncMu           sync.Mutex
 	syncableDentries map[*dentry]struct{}
 	specialFileFDs   map[*specialFileFD]struct{}
+
+	// syntheticSeq stores a counter used to generate unique inodeNumbers for
+	// synthetic dentries.
+	syntheticSeq uint64
+}
+
+// inodeNumber represents the inode number reported in Dirent.Ino. For regular
+// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
+// have their inodeNumber generated sequentially, with the MSB reserved to
+// prevent conflicts with regular dentries.
+type inodeNumber uint64
+
+// Reserve MSB for synthetic dentries.
+const syntheticInoMask = uint64(1) << 63
+
+func inoFromPath(path uint64) inodeNumber {
+	if path&syntheticInoMask != 0 {
+		log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
+	}
+	return inodeNumber(path &^ syntheticInoMask)
 }
 
 type filesystemOptions struct {
@@ -583,21 +602,27 @@ type dentry struct {
 	// returned by the server. dirents is protected by dirMu.
 	dirents []vfs.Dirent
 
-	// Cached metadata; protected by metadataMu and accessed using atomic
-	// memory operations unless otherwise specified.
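inoFromPath and nextSyntheticIno above split the 64-bit inode space in two: server-provided QID paths keep the low 63 bits, while synthetic dentries take the MSB-tagged half. A self-contained sketch of that partitioning, with the sequencing simplified to a plain counter (the real code uses atomic.AddUint64):

package main

import "fmt"

// The MSB is reserved for synthetic dentries, as above.
const syntheticInoMask = uint64(1) << 63

// inoFromPath strips the reserved MSB from a server-provided QID path (the
// real code also logs a warning about the possible collision).
func inoFromPath(path uint64) uint64 { return path &^ syntheticInoMask }

// nextSyntheticIno tags a fresh sequence number with the reserved MSB.
func nextSyntheticIno(seq *uint64) uint64 {
	*seq++
	return *seq | syntheticInoMask
}

func main() {
	fmt.Printf("%#x\n", inoFromPath(0x8000000000000003)) // 0x3: MSB dropped
	var seq uint64
	fmt.Printf("%#x\n", nextSyntheticIno(&seq)) // 0x8000000000000001
}

+	// Cached metadata; protected by metadataMu.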
+	// To access:
+	//   - In situations where consistency is not required (like stat), these
+	//     can be accessed using atomic operations only (without locking).
+	//   - Lock metadataMu, then access without atomic operations.
+	// To mutate:
+	//   - Lock metadataMu and use atomic operations to update because we might
+	//     have atomic readers that don't hold the lock.
 	metadataMu sync.Mutex
-	ino       uint64 // immutable
-	mode      uint32 // type is immutable, perms are mutable
-	uid       uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
-	gid       uint32 // auth.KGID, but ...
-	blockSize uint32 // 0 if unknown
+	ino        inodeNumber // immutable
+	mode       uint32      // type is immutable, perms are mutable
+	uid        uint32      // auth.KUID, but stored as raw uint32 for sync/atomic
+	gid        uint32      // auth.KGID, but ...
+	blockSize  uint32      // 0 if unknown
 	// Timestamps, all nsecs from the Unix epoch.
 	atime int64
 	mtime int64
 	ctime int64
 	btime int64
 	// File size, protected by both metadataMu and dataMu (i.e. both must be
-	// locked to mutate it).
+	// locked to mutate it; locking either is sufficient to access it).
 	size uint64
 
 	// nlink counts the number of hard links to this dentry. It's updated and
@@ -665,7 +690,10 @@ type dentry struct {
 	// endpoint bound to this file.
 	pipe *pipe.VFSPipe
 
-	locks lock.FileLocks
+	locks vfs.FileLocks
+
+	// Inotify watches for this dentry.
+	watches vfs.Watches
 }
 
 // dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
@@ -702,7 +730,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
 	d := &dentry{
 		fs:        fs,
 		file:      file,
-		ino:       qid.Path,
+		ino:       inoFromPath(qid.Path),
 		mode:      uint32(attr.Mode),
 		uid:       uint32(fs.opts.dfltuid),
 		gid:       uint32(fs.opts.dfltgid),
@@ -757,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool {
 
 // updateFromP9Attrs is called to update d's metadata after an update from the
 // remote filesystem.
-func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
-	d.metadataMu.Lock()
+// Precondition: d.metadataMu must be locked.
+func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
 	if mask.Mode {
 		if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
 			d.metadataMu.Unlock()
@@ -792,11 +820,8 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
 		atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
 	}
 	if mask.Size {
-		d.dataMu.Lock()
-		atomic.StoreUint64(&d.size, attr.Size)
-		d.dataMu.Unlock()
+		d.updateFileSizeLocked(attr.Size)
 	}
-	d.metadataMu.Unlock()
 }
 
 // Preconditions: !d.isSynthetic()
@@ -808,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
 		file            p9file
 		handleMuRLocked bool
 	)
+	// d.metadataMu must be locked *before* we getAttr so that we do not end up
+	// updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock() + defer d.metadataMu.Unlock() d.handleMu.RLock() if !d.handle.file.isNil() { file = d.handle.file @@ -823,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { if err != nil { return err } - d.updateFromP9Attrs(attrMask, &attr) + d.updateFromP9AttrsLocked(attrMask, &attr) return nil } @@ -835,10 +864,18 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME stat.Blksize = atomic.LoadUint32(&d.blockSize) stat.Nlink = atomic.LoadUint32(&d.nlink) + if stat.Nlink == 0 { + // The remote filesystem doesn't support link count; just make + // something up. This is consistent with Linux, where + // fs/inode.c:inode_init_always() initializes link count to 1, and + // fs/9p/vfs_inode_dotl.c:v9fs_stat2inode_dotl() doesn't touch it if + // it's not provided by the remote filesystem. + stat.Nlink = 1 + } stat.UID = atomic.LoadUint32(&d.uid) stat.GID = atomic.LoadUint32(&d.gid) stat.Mode = uint16(atomic.LoadUint32(&d.mode)) - stat.Ino = d.ino + stat.Ino = uint64(d.ino) stat.Size = atomic.LoadUint64(&d.size) // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. @@ -851,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.DevMinor = d.fs.devMinor } -func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error { +func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -859,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } if err := mnt.CheckBeginWrite(); err != nil { @@ -876,14 +914,14 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin // Prepare for truncate. if stat.Mask&linux.STATX_SIZE != 0 { - switch d.mode & linux.S_IFMT { - case linux.S_IFREG: + switch mode.FileType() { + case linux.ModeRegular: if !setLocalMtime { // Truncate updates mtime. setLocalMtime = true stat.Mtime.Nsec = linux.UTIME_NOW } - case linux.S_IFDIR: + case linux.ModeDirectory: return syserror.EISDIR default: return syserror.EINVAL @@ -892,8 +930,25 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } d.metadataMu.Lock() defer d.metadataMu.Unlock() + if stat.Mask&linux.STATX_SIZE != 0 { + // The size needs to be changed even when + // !d.cachedMetadataAuthoritative() because d.mappings has to be + // updated. + d.updateFileSizeLocked(stat.Size) + } if !d.isSynthetic() { if stat.Mask != 0 { + if stat.Mask&linux.STATX_SIZE != 0 { + // Check whether to allow a truncate request to be made. + switch d.mode & linux.S_IFMT { + case linux.S_IFREG: + // Allow. 
+ case linux.S_IFDIR: + return syserror.EISDIR + default: + return syserror.EINVAL + } + } if err := d.file.setAttr(ctx, p9.SetAttrMask{ Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, @@ -940,6 +995,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } else { atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) } + // Restore mask bits that we cleared earlier. + stat.Mask |= linux.STATX_ATIME } if setLocalMtime { if stat.Mtime.Nsec == linux.UTIME_NOW { @@ -947,48 +1004,56 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } else { atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) } + // Restore mask bits that we cleared earlier. + stat.Mask |= linux.STATX_MTIME } atomic.StoreInt64(&d.ctime, now) - if stat.Mask&linux.STATX_SIZE != 0 { + return nil +} + +// Preconditions: d.metadataMu must be locked. +func (d *dentry) updateFileSizeLocked(newSize uint64) { + d.dataMu.Lock() + oldSize := d.size + atomic.StoreUint64(&d.size, newSize) + // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings + // below. This allows concurrent calls to Read/Translate/etc. These + // functions synchronize with truncation by refusing to use cache + // contents beyond the new d.size. (We are still holding d.metadataMu, + // so we can't race with Write or another truncate.) + d.dataMu.Unlock() + if d.size < oldSize { + oldpgend, _ := usermem.PageRoundUp(oldSize) + newpgend, _ := usermem.PageRoundUp(d.size) + if oldpgend != newpgend { + d.mapsMu.Lock() + d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + d.mapsMu.Unlock() + } + // We are now guaranteed that there are no translations of + // truncated pages, and can remove them from the cache. Since + // truncated pages have been removed from the remote file, they + // should be dropped without being written back. d.dataMu.Lock() - oldSize := d.size - d.size = stat.Size - // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings - // below. This allows concurrent calls to Read/Translate/etc. These - // functions synchronize with truncation by refusing to use cache - // contents beyond the new d.size. (We are still holding d.metadataMu, - // so we can't race with Write or another truncate.) + d.cache.Truncate(d.size, d.fs.mfp.MemoryFile()) + d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend}) d.dataMu.Unlock() - if d.size < oldSize { - oldpgend, _ := usermem.PageRoundUp(oldSize) - newpgend, _ := usermem.PageRoundUp(d.size) - if oldpgend != newpgend { - d.mapsMu.Lock() - d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/truncate.c:truncate_setsize() => - // truncate_pagecache() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - d.mapsMu.Unlock() - } - // We are now guaranteed that there are no translations of - // truncated pages, and can remove them from the cache. Since - // truncated pages have been removed from the remote file, they - // should be dropped without being written back. 
-		d.dataMu.Lock()
-		d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
-		d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
-		d.dataMu.Unlock()
-	}
 	}
-	return nil
 }
 
 func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
 	return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
 }
 
+func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
+	return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid)))
+}
+
 func dentryUIDFromP9UID(uid p9.UID) uint32 {
 	if !uid.Ok() {
 		return uint32(auth.OverflowUID)
 	}
@@ -1044,15 +1109,34 @@ func (d *dentry) decRefLocked() {
 }
 
 // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
-//
-// TODO(gvisor.dev/issue/1479): Implement inotify.
-func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {}
+func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {
+	if d.isDir() {
+		events |= linux.IN_ISDIR
+	}
+
+	d.fs.renameMu.RLock()
+	// The ordering below is important: Linux always notifies the parent first.
+	if d.parent != nil {
+		d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted())
+	}
+	d.watches.Notify("", events, cookie, et, d.isDeleted())
+	d.fs.renameMu.RUnlock()
+}
 
 // Watches implements vfs.DentryImpl.Watches.
-//
-// TODO(gvisor.dev/issue/1479): Implement inotify.
 func (d *dentry) Watches() *vfs.Watches {
-	return nil
+	return &d.watches
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// If no watches are left on this dentry and it has no references, cache it.
+func (d *dentry) OnZeroWatches() {
+	if atomic.LoadInt64(&d.refs) == 0 {
+		d.fs.renameMu.Lock()
+		d.checkCachingLocked()
+		d.fs.renameMu.Unlock()
+	}
 }
 
 // checkCachingLocked should be called after d's reference count becomes 0 or it
@@ -1086,6 +1170,9 @@ func (d *dentry) checkCachingLocked() {
 	// Deleted and invalidated dentries with zero references are no longer
 	// reachable by path resolution and should be dropped immediately.
 	if d.vfsd.IsDead() {
+		if d.isDeleted() {
+			d.watches.HandleDeletion()
+		}
 		if d.cached {
 			d.fs.cachedDentries.Remove(d)
 			d.fs.cachedDentriesLen--
@@ -1094,6 +1181,14 @@ func (d *dentry) checkCachingLocked() {
 		d.destroyLocked()
 		return
 	}
+	// If d still has inotify watches and it is not deleted or invalidated, we
+	// cannot cache it and allow it to be evicted. Otherwise, we will lose its
+	// watches, even if a new dentry is created for the same file in the future.
+	// Note that the size of d.watches cannot concurrently transition from zero
+	// to non-zero, because adding a watch requires holding a reference on d.
+	if d.watches.Size() > 0 {
+		return
+	}
 	// If d is already cached, just move it to the front of the LRU.
 	if d.cached {
 		d.fs.cachedDentries.Remove(d)
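The xattr methods that follow all apply the same two gates: a user. name prefix and a file type for which user xattrs are supported. A compact, standard-library model of that filter:

package main

import (
	"fmt"
	"os"
	"strings"
)

// userXattrAllowed models the checks repeated in listxattr/getxattr/setxattr/
// removexattr below: only user.* attributes on regular files or directories
// are exposed through the gofer.
func userXattrAllowed(name string, mode os.FileMode) bool {
	return strings.HasPrefix(name, "user.") && (mode.IsRegular() || mode.IsDir())
}

func main() {
	fmt.Println(userXattrAllowed("user.tag", 0))              // true: regular file
	fmt.Println(userXattrAllowed("trusted.x", 0))             // false: wrong namespace
	fmt.Println(userXattrAllowed("user.tag", os.ModeSymlink)) // false: symlink
}

@@ -1199,7 +1294,7 @@ func (d *dentry) setDeleted() {
 // We only support xattrs prefixed with "user." (see b/148380782). Currently,
 // there is no need to expose any other xattrs through a gofer.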
func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { - if d.file.isNil() { + if d.file.isNil() || !d.userXattrSupported() { return nil, nil } xattrMap, err := d.file.listXattr(ctx, size) @@ -1225,6 +1320,9 @@ func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vf if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { return "", syserror.EOPNOTSUPP } + if !d.userXattrSupported() { + return "", syserror.ENODATA + } return d.file.getXattr(ctx, opts.Name, opts.Size) } @@ -1238,6 +1336,9 @@ func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vf if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { return syserror.EOPNOTSUPP } + if !d.userXattrSupported() { + return syserror.EPERM + } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } @@ -1251,10 +1352,20 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { return syserror.EOPNOTSUPP } + if !d.userXattrSupported() { + return syserror.EPERM + } return d.file.removeXattr(ctx, name) } -// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDirectory(). +// Extended attributes in the user.* namespace are only supported for regular +// files and directories. +func (d *dentry) userXattrSupported() bool { + filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType() + return filetype == linux.ModeRegular || filetype == linux.ModeDirectory +} + +// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir(). func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error { // O_TRUNC unconditionally requires us to obtain a new handle (opened with // O_TRUNC). @@ -1346,23 +1457,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool } // incLinks increments link count. -// -// Preconditions: d.nlink != 0 && d.nlink < math.MaxUint32. func (d *dentry) incLinks() { - v := atomic.AddUint32(&d.nlink, 1) - if v < 2 { - panic(fmt.Sprintf("dentry.nlink is invalid (was 0 or overflowed): %d", v)) + if atomic.LoadUint32(&d.nlink) == 0 { + // The remote filesystem doesn't support link count. + return } + atomic.AddUint32(&d.nlink, 1) } // decLinks decrements link count. -// -// Preconditions: d.nlink > 1. func (d *dentry) decLinks() { - v := atomic.AddUint32(&d.nlink, ^uint32(0)) - if v == 0 { - panic(fmt.Sprintf("dentry.nlink must be greater than 0: %d", v)) + if atomic.LoadUint32(&d.nlink) == 0 { + // The remote filesystem doesn't support link count. + return } + atomic.AddUint32(&d.nlink, ^uint32(0)) } // fileDescription is embedded by gofer implementations of @@ -1401,7 +1510,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - return fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()) + if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil { + return err + } + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // Listxattr implements vfs.FileDescriptionImpl.Listxattr. @@ -1416,12 +1531,22 @@ func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOption // Setxattr implements vfs.FileDescriptionImpl.Setxattr. 
func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
-	return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
+	d := fd.dentry()
+	if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
+		return err
+	}
+	d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+	return nil
 }
 
 // Removexattr implements vfs.FileDescriptionImpl.Removexattr.
 func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
-	return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name)
+	d := fd.dentry()
+	if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
+		return err
+	}
+	d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent)
+	return nil
 }
 
 // LockBSD implements vfs.FileDescriptionImpl.LockBSD.
@@ -1433,9 +1558,14 @@ func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t f
 }
 
 // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
-func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
 	fd.lockLogging.Do(func() {
 		log.Infof("Range lock using gofer file handled internally.")
 	})
-	return fd.LockFD.LockPOSIX(ctx, uid, t, rng, block)
+	return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+	return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
 }
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 724a3f1f7..8792ca4f2 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -126,11 +126,16 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
 }
 
 func (h *handle) sync(ctx context.Context) error {
+	// Handle the most common case first.
 	if h.fd >= 0 {
 		ctx.UninterruptibleSleepStart(false)
 		err := syscall.Fsync(int(h.fd))
 		ctx.UninterruptibleSleepFinish(false)
 		return err
 	}
+	if h.file.isNil() {
+		// The file hasn't been touched; there is nothing to sync.
+		return nil
+	}
 	return h.file.fsync(ctx)
 }
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 0d10cf7ac..09f142cfc 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -24,11 +24,11 @@ import (
 	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/context"
 	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/p9"
 	"gvisor.dev/gvisor/pkg/safemem"
 	"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
 	"gvisor.dev/gvisor/pkg/sentry/memmap"
 	"gvisor.dev/gvisor/pkg/sentry/pgalloc"
-	"gvisor.dev/gvisor/pkg/sentry/platform"
 	"gvisor.dev/gvisor/pkg/sentry/usage"
 	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/syserror"
@@ -67,12 +67,46 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error {
 	return d.handle.file.flush(ctx)
 }
 
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+	d := fd.dentry()
+	d.metadataMu.Lock()
+	defer d.metadataMu.Unlock()
+
+	size := offset + length
+
+	// Allocating a smaller size is a no-op.
+	if size <= d.size {
+		return nil
+	}
+
+	d.handleMu.Lock()
+	defer d.handleMu.Unlock()
+
+	err := d.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
+	if err != nil {
+		return err
+	}
+	d.dataMu.Lock()
+	atomic.StoreUint64(&d.size, size)
+	d.dataMu.Unlock()
+	if !d.cachedMetadataAuthoritative() {
+		d.touchCMtimeLocked()
+	}
+	return nil
+}
+
 // PRead implements vfs.FileDescriptionImpl.PRead.
 func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
 	if offset < 0 {
 		return 0, syserror.EINVAL
 	}
-	if opts.Flags != 0 {
+
+	// Check that flags are supported.
+	//
+	// TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+	if opts.Flags&^linux.RWF_HIPRI != 0 {
 		return 0, syserror.EOPNOTSUPP
 	}
 
@@ -120,21 +154,53 @@
 // PWrite implements vfs.FileDescriptionImpl.PWrite.
 func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+	n, _, err := fd.pwrite(ctx, src, offset, opts)
+	return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an error.
+// The final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
 	if offset < 0 {
-		return 0, syserror.EINVAL
+		return 0, offset, syserror.EINVAL
 	}
-	if opts.Flags != 0 {
-		return 0, syserror.EOPNOTSUPP
+
+	// Check that flags are supported.
+	//
+	// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+	if opts.Flags&^linux.RWF_HIPRI != 0 {
+		return 0, offset, syserror.EOPNOTSUPP
+	}
+
+	d := fd.dentry()
+	// If the fd was opened with O_APPEND, make sure the file size is updated.
+	// There is a possible race here if size is modified externally after
+	// metadata cache is updated.
+	if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+		if err := d.updateFromGetattr(ctx); err != nil {
+			return 0, offset, err
+		}
+	}
+
+	d.metadataMu.Lock()
+	defer d.metadataMu.Unlock()
+	// Set offset to file size if the fd was opened with O_APPEND.
+	if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+		// Holding d.metadataMu is sufficient for reading d.size.
+		offset = int64(d.size)
 	}
 	limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
 	if err != nil {
-		return 0, err
+		return 0, offset, err
 	}
 	src = src.TakeFirst64(limit)
+	n, err := fd.pwriteLocked(ctx, src, offset, opts)
+	return n, offset + n, err
+}
 
+// Preconditions: fd.dentry().metadataMu must be locked.
+func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
 	d := fd.dentry()
-	d.metadataMu.Lock()
-	defer d.metadataMu.Unlock()
 	if d.fs.opts.interop != InteropModeShared {
 		// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
 		// file_update_time(). 
This is d.touchCMtime(), but without locking @@ -154,12 +220,12 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, syserror.EINVAL } mr := memmap.MappableRange{pgstart, pgend} - var freed []platform.FileRange + var freed []memmap.FileRange d.dataMu.Lock() cseg := d.cache.LowerBoundSegment(mr.Start) for cseg.Ok() && cseg.Start() < mr.End { cseg = d.cache.Isolate(cseg, mr) - freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) + freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) cseg = d.cache.Remove(cseg).NextSegment() } d.dataMu.Unlock() @@ -197,8 +263,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -489,15 +555,24 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error { func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() defer fd.mu.Unlock() + newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence) + if err != nil { + return 0, err + } + fd.off = newOffset + return newOffset, nil +} + +// Calculate the new offset for a seek operation on a regular file. +func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int64, whence int32) (int64, error) { switch whence { case linux.SEEK_SET: // Use offset as specified. case linux.SEEK_CUR: - offset += fd.off + offset += fdOffset case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE: // Ensure file size is up to date. - d := fd.dentry() - if fd.filesystem().opts.interop == InteropModeShared { + if !d.cachedMetadataAuthoritative() { if err := d.updateFromGetattr(ctx); err != nil { return 0, err } @@ -525,7 +600,6 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( if offset < 0 { return 0, syserror.EINVAL } - fd.off = offset return offset, nil } @@ -536,20 +610,19 @@ func (fd *regularFileFD) Sync(ctx context.Context) error { func (d *dentry) syncSharedHandle(ctx context.Context) error { d.handleMu.RLock() - if !d.handleWritable { - d.handleMu.RUnlock() - return nil - } - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) - d.dataMu.Unlock() - if err == nil { - // Sync the remote file. - err = d.handle.sync(ctx) + defer d.handleMu.RUnlock() + + if d.handleWritable { + d.dataMu.Lock() + // Write dirty cached data to the remote file. + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } } - d.handleMu.RUnlock() - return err + // Sync the remote file. + return d.handle.sync(ctx) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -747,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. 
func (d *dentry) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. d.mapsMu.Lock() @@ -795,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { } } -// dentryPlatformFile implements platform.File. It exists solely because dentry -// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef. +// dentryPlatformFile implements memmap.File. It exists solely because dentry +// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file // is available (i.e. dentry.handle.fd >= 0), and that FD is used for @@ -804,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { type dentryPlatformFile struct { *dentry - // fdRefs counts references on platform.File offsets. fdRefs is protected + // fdRefs counts references on memmap.File offsets. fdRefs is protected // by dentry.dataMu. fdRefs fsutil.FrameRefSet @@ -816,29 +889,29 @@ type dentryPlatformFile struct { hostFileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. -func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.IncRefAndAccount(fr) d.dataMu.Unlock() } -// DecRef implements platform.File.DecRef. -func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) d.handleMu.RUnlock() return bs, err } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() fd := d.handle.fd diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 289efdd25..811528982 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -16,22 +16,22 @@ package gofer import ( "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) // specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device -// special files, and (when filesystemOptions.specialRegularFiles is in effect) -// regular files. specialFileFD differs from regularFileFD by using per-FD -// handles instead of shared per-dentry handles, and never buffering I/O. +// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is +// in effect) regular files. 
specialFileFD differs from regularFileFD by using +// per-FD handles instead of shared per-dentry handles, and never buffering I/O. type specialFileFD struct { fileDescription @@ -42,27 +42,27 @@ type specialFileFD struct { // file offset is significant, i.e. a regular file. seekable is immutable. seekable bool - // mayBlock is true if this file description represents a file for which - // queue may send I/O readiness events. mayBlock is immutable. - mayBlock bool - queue waiter.Queue + // haveQueue is true if this file description represents a file for which + // queue may send I/O readiness events. haveQueue is immutable. + haveQueue bool + queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex off int64 } -func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *lock.FileLocks, flags uint32) (*specialFileFD, error) { +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG - mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK + haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 fd := &specialFileFD{ - handle: h, - seekable: seekable, - mayBlock: mayBlock, + handle: h, + seekable: seekable, + haveQueue: haveQueue, } fd.LockFD.Init(locks) - if mayBlock && h.fd >= 0 { + if haveQueue { if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { return nil, err } @@ -71,7 +71,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *lock.FileLocks DenyPRead: !seekable, DenyPWrite: !seekable, }); err != nil { - if mayBlock && h.fd >= 0 { + if haveQueue { fdnotifier.RemoveFD(h.fd) } return nil, err @@ -81,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *lock.FileLocks // Release implements vfs.FileDescriptionImpl.Release. func (fd *specialFileFD) Release() { - if fd.mayBlock && fd.handle.fd >= 0 { + if fd.haveQueue { fdnotifier.RemoveFD(fd.handle.fd) } fd.handle.close(context.Background()) @@ -101,7 +101,7 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { // Readiness implements waiter.Waitable.Readiness. func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { - if fd.mayBlock { + if fd.haveQueue { return fdnotifier.NonBlockingPoll(fd.handle.fd, mask) } return fd.fileDescription.Readiness(mask) @@ -109,8 +109,9 @@ func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { // EventRegister implements waiter.Waitable.EventRegister. func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - if fd.mayBlock { + if fd.haveQueue { fd.queue.EventRegister(e, mask) + fdnotifier.UpdateFD(fd.handle.fd) return } fd.fileDescription.EventRegister(e, mask) @@ -118,8 +119,9 @@ func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { // EventUnregister implements waiter.Waitable.EventUnregister. func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { - if fd.mayBlock { + if fd.haveQueue { fd.queue.EventUnregister(e) + fdnotifier.UpdateFD(fd.handle.fd) return } fd.fileDescription.EventUnregister(e) @@ -130,7 +132,11 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs if fd.seekable && offset < 0 { return 0, syserror.EINVAL } - if opts.Flags != 0 { + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. 
+ if opts.Flags&^linux.RWF_HIPRI != 0 { return 0, syserror.EOPNOTSUPP } @@ -139,7 +145,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't // hold here since specialFileFD doesn't client-cache data. Just buffer the // read instead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } buf := make([]byte, dst.NumBytes()) @@ -171,35 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if fd.seekable && offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + + d := fd.dentry() + // If the regular file fd was opened with O_APPEND, make sure the file size + // is updated. There is a possible race here if size is modified externally + // after metadata cache is updated. + if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } if fd.seekable { + // We need to hold the metadataMu *while* writing to a regular file. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Set offset to file size if the regular file was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) + } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) } // Do a buffered write. See rationale in PRead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d.cachedMetadataAuthoritative() { d.touchCMtime() } buf := make([]byte, src.NumBytes()) // Don't do partial writes if we get a partial read from src. if _, err := src.CopyIn(ctx, buf); err != nil { - return 0, err + return 0, offset, err } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } - return int64(n), err + finalOff = offset + // Update file size for regular files. + if fd.seekable { + finalOff += int64(n) + // d.metadataMu is already locked at this point. + if uint64(finalOff) > d.size { + d.dataMu.Lock() + defer d.dataMu.Unlock() + atomic.StoreUint64(&d.size, uint64(finalOff)) + } + } + return int64(n), finalOff, err } // Write implements vfs.FileDescriptionImpl.Write. 
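The pwrite/Write split above is the heart of this change: pwrite reports the final offset rather than letting the caller advance by the byte count, because once O_APPEND is honored the write may start at end-of-file instead of at the cached offset. That is why Write, in the hunk that follows, switches from fd.off += n to fd.off = off. Below is a minimal, self-contained sketch of the pattern; fakeFD, its fields, and the in-memory backing are illustrative stand-ins, not gVisor types.

package main

import (
	"fmt"
	"sync"
)

// fakeFD is a toy seekable file description.
type fakeFD struct {
	mu       sync.Mutex
	off      int64 // cached file offset, guarded by mu
	size     int64 // current file size
	isAppend bool  // opened with O_APPEND
}

// pwrite returns the bytes written and the final offset. Under O_APPEND the
// effective start offset is the file size, so finalOff can differ from the
// caller's offset plus the byte count.
func (fd *fakeFD) pwrite(buf []byte, offset int64) (written, finalOff int64, err error) {
	if fd.isAppend {
		offset = fd.size
	}
	n := int64(len(buf)) // pretend the backing store accepts everything
	if offset+n > fd.size {
		fd.size = offset + n
	}
	return n, offset + n, nil
}

// Write mirrors the diff: lock, write at the cached offset, then adopt the
// final offset instead of doing fd.off += n.
func (fd *fakeFD) Write(buf []byte) (int64, error) {
	fd.mu.Lock()
	defer fd.mu.Unlock()
	n, off, err := fd.pwrite(buf, fd.off)
	fd.off = off
	return n, err
}

func main() {
	fd := &fakeFD{size: 100, isAppend: true}
	fd.Write([]byte("abc"))
	fmt.Println(fd.off) // 103, not 3: the write landed at EOF first
}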
@@ -209,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts } fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -222,27 +269,15 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( } fd.mu.Lock() defer fd.mu.Unlock() - switch whence { - case linux.SEEK_SET: - // Use offset as given. - case linux.SEEK_CUR: - offset += fd.off - default: - // SEEK_END, SEEK_DATA, and SEEK_HOLE aren't supported since it's not - // clear that file size is even meaningful for these files. - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL + newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence) + if err != nil { + return 0, err } - fd.off = offset - return offset, nil + fd.off = newOffset + return newOffset, nil } // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - if !fd.vfsfd.IsWritable() { - return nil - } - return fd.handle.sync(ctx) + return fd.dentry().syncSharedHandle(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 1d5aa82dc..0eef4e16e 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -36,7 +36,7 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { } } -// Preconditions: fs.interop != InteropModeShared. +// Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime { return @@ -51,8 +51,8 @@ func (d *dentry) touchAtime(mnt *vfs.Mount) { mnt.EndWrite() } -// Preconditions: fs.interop != InteropModeShared. The caller has successfully -// called vfs.Mount.CheckBeginWrite(). +// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has +// successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -60,8 +60,8 @@ func (d *dentry) touchCtime() { d.metadataMu.Unlock() } -// Preconditions: fs.interop != InteropModeShared. The caller has successfully -// called vfs.Mount.CheckBeginWrite(). +// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has +// successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCMtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -70,6 +70,8 @@ func (d *dentry) touchCMtime() { d.metadataMu.Unlock() } +// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has +// locked d.metadataMu. 
func (d *dentry) touchCMtimeLocked() { now := d.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&d.mtime, now) diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 54f16ad63..bd701bbc7 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -22,24 +22,24 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/fspath", + "//pkg/iovec", "//pkg/log", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 5ec5100b8..c894f2ca0 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -28,13 +28,13 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -91,7 +91,9 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) isTTY: opts.IsTTY, wouldBlock: wouldBlock(uint32(fileType)), seekable: seekable, - canMap: canMap(uint32(fileType)), + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + canMap: fileType == linux.S_IFREG, } i.pf.inode = i @@ -183,7 +185,7 @@ type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks // When the reference count reaches zero, the host fd is closed. refs.AtomicRefCount @@ -257,7 +259,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL } @@ -371,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { // SetStat implements kernfs.Inode. 
func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - s := opts.Stat + s := &opts.Stat m := s.Mask if m == 0 { @@ -384,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if err := syscall.Fstat(i.hostFD, &hostStat); err != nil { return err } - if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { return err } @@ -394,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } } if m&linux.STATX_SIZE != 0 { + if hostStat.Mode&linux.S_IFMT != linux.S_IFREG { + return syserror.EINVAL + } if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err } @@ -457,10 +462,12 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u fileType := s.Mode & linux.FileTypeMask // Constrain flags to a subset we can handle. - // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags. - flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND + // + // TODO(gvisor.dev/issue/2601): Support O_NONBLOCK by adding RWF_NOWAIT to pread/pwrite calls. + flags &= syscall.O_ACCMODE | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND - if fileType == syscall.S_IFSOCK { + switch fileType { + case syscall.S_IFSOCK: if i.isTTY { log.Warningf("cannot use host socket fd %d as TTY", i.hostFD) return nil, syserror.ENOTTY @@ -472,30 +479,33 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u } // Currently, we only allow Unix sockets to be imported. return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks) - } - // TODO(gvisor.dev/issue/1672): Whitelist specific file types here, so that - // we don't allow importing arbitrary file types without proper support. - if i.isTTY { - fd := &TTYFileDescription{ - fileDescription: fileDescription{inode: i}, - termios: linux.DefaultSlaveTermios, + case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR: + if i.isTTY { + fd := &TTYFileDescription{ + fileDescription: fileDescription{inode: i}, + termios: linux.DefaultSlaveTermios, + } + fd.LockFD.Init(&i.locks) + vfsfd := &fd.vfsfd + if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return vfsfd, nil } + + fd := &fileDescription{inode: i} fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil - } - fd := &fileDescription{inode: i} - fd.LockFD.Init(&i.locks) - vfsfd := &fd.vfsfd - if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err + default: + log.Warningf("cannot import host fd %d with file type %o", i.hostFD, fileType) + return nil, syserror.EPERM } - return vfsfd, nil } // fileDescription is embedded by host fd implementations of FileDescriptionImpl. @@ -527,8 +537,8 @@ func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // Stat implements vfs.FileDescriptionImpl. 
-func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { - return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts) +func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts) } // Release implements vfs.FileDescriptionImpl. @@ -536,6 +546,16 @@ func (f *fileDescription) Release() { // noop } +// Allocate implements vfs.FileDescriptionImpl. +func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { + if !f.inode.seekable { + return syserror.ESPIPE + } + + // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds. + return syserror.EOPNOTSUPP +} + // PRead implements FileDescriptionImpl. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { i := f.inode @@ -562,7 +582,7 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts } return n, err } - // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. + f.offsetMu.Lock() n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags) f.offset += n @@ -571,8 +591,10 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts } func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // TODO(gvisor.dev/issue/1672): Support select preadv2 flags. - if flags != 0 { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if flags&^linux.RWF_HIPRI != 0 { return 0, syserror.EOPNOTSUPP } reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) @@ -583,41 +605,58 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off // PWrite implements FileDescriptionImpl. func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - i := f.inode - if !i.seekable { + if !f.inode.seekable { return 0, syserror.ESPIPE } - return writeToHostFD(ctx, i.hostFD, src, offset, opts.Flags) + return f.writeToHostFD(ctx, src, offset, opts.Flags) } // Write implements FileDescriptionImpl. func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { i := f.inode if !i.seekable { - n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags) + n, err := f.writeToHostFD(ctx, src, -1, opts.Flags) if isBlockError(err) { err = syserror.ErrWouldBlock } return n, err } - // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. - // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file. + f.offsetMu.Lock() - n, err := writeToHostFD(ctx, i.hostFD, src, f.offset, opts.Flags) + // NOTE(gvisor.dev/issue/2983): O_APPEND may cause memory corruption if + // another process modifies the host file between retrieving the file size + // and writing to the host fd. This is an unavoidable race condition because + // we cannot enforce synchronization on the host. 
+ if f.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + var s syscall.Stat_t + if err := syscall.Fstat(i.hostFD, &s); err != nil { + f.offsetMu.Unlock() + return 0, err + } + f.offset = s.Size + } + n, err := f.writeToHostFD(ctx, src, f.offset, opts.Flags) f.offset += n f.offsetMu.Unlock() return n, err } -func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // TODO(gvisor.dev/issue/1672): Support select pwritev2 flags. +func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSequence, offset int64, flags uint32) (int64, error) { + hostFD := f.inode.hostFD + // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. if flags != 0 { return 0, syserror.EOPNOTSUPP } writer := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := src.CopyInTo(ctx, writer) hostfd.PutReadWriterAt(writer) + // NOTE(gvisor.dev/issue/2979): We always sync everything, even for O_DSYNC. + if n > 0 && f.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 { + if syncErr := unix.Fsync(hostFD); syncErr != nil { + return int64(n), syncErr + } + } return int64(n), err } @@ -688,7 +727,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i // Sync implements FileDescriptionImpl. func (f *fileDescription) Sync(context.Context) error { - // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData optimization, so we always sync everything. + // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } @@ -718,3 +757,13 @@ func (f *fileDescription) EventUnregister(e *waiter.Entry) { func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index 8545a82f0..65d3af38c 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -19,13 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// inodePlatformFile implements platform.File. It exists solely because inode -// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// inodePlatformFile implements memmap.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // // inodePlatformFile should only be used if inode.canMap is true. type inodePlatformFile struct { @@ -34,7 +33,7 @@ type inodePlatformFile struct { // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex - // fdRefs counts references on platform.File offsets. It is used solely for + // fdRefs counts references on memmap.File offsets. It is used solely for // memory accounting. 
fdRefs fsutil.FrameRefSet @@ -45,32 +44,32 @@ type inodePlatformFile struct { fileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. +// IncRef implements memmap.File.IncRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) IncRef(fr platform.FileRange) { +func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) i.fdRefsMu.Unlock() } -// DecRef implements platform.File.DecRef. +// DecRef implements memmap.File.DecRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) DecRef(fr platform.FileRange) { +func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) i.fdRefsMu.Unlock() } -// MapInternal implements platform.File.MapInternal. +// MapInternal implements memmap.File.MapInternal. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (i *inodePlatformFile) FD() int { return i.hostFD } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 38f1fbfba..fd16bd92d 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -47,11 +47,6 @@ func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transpor return ep, nil } -// maxSendBufferSize is the maximum host send buffer size allowed for endpoint. -// -// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max). -const maxSendBufferSize = 8 << 20 - // ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and // transport.Receiver. It is backed by a host fd that was imported at sentry // startup. This fd is shared with a hostfs inode, which retains ownership of @@ -114,10 +109,6 @@ func (c *ConnectedEndpoint) init() *syserr.Error { if err != nil { return syserr.FromError(err) } - if sndbuf > maxSendBufferSize { - log.Warningf("Socket send buffer too large: %d", sndbuf) - return syserr.ErrInvalidEndpointState - } c.stype = linux.SockType(stype) c.sndbuf = int64(sndbuf) diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go index 584c247d2..fc0d5fd38 100644 --- a/pkg/sentry/fsimpl/host/socket_iovec.go +++ b/pkg/sentry/fsimpl/host/socket_iovec.go @@ -17,13 +17,10 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -74,7 +71,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. 
b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 68af6e5af..4ee9270cc 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -325,9 +326,9 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) task := kernel.TaskFromContext(ctx) if task == nil { // No task? Linux does not have an analog for this case, but - // tty_check_change is more of a blacklist of cases than a - // whitelist, and is surprisingly permissive. Allowing the - // change seems most appropriate. + // tty_check_change only blocks specific cases and is + // surprisingly permissive. Allowing the change seems + // appropriate. return nil } @@ -377,3 +378,13 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) return kernel.ERESTARTSYS } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 2bc757b1a..412bdb2eb 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -49,16 +49,6 @@ func wouldBlock(fileType uint32) bool { return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK } -// canMap returns true if a file with fileType is allowed to be memory mapped. -// This is ported over from VFS1, but it's probably not the best way for us -// to check if a file can be memory mapped. -func canMap(fileType uint32) bool { - // TODO(gvisor.dev/issue/1672): Also allow "special files" to be mapped (see fs/host:canMap()). - // - // TODO(b/38213152): Some obscure character devices can be mapped. - return fileType == syscall.S_IFREG -} - // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. 
func isBlockError(err error) bool { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 0299dbde9..3835557fe 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -45,11 +45,11 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", @@ -68,9 +68,8 @@ go_test( "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 6418de0a3..c6c4472e7 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -19,9 +19,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -39,7 +39,7 @@ type DynamicBytesFile struct { InodeNotDirectory InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks data vfs.DynamicBytesSource } @@ -86,7 +86,7 @@ type DynamicBytesFD struct { } // Init initializes a DynamicBytesFD. -func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *lock.FileLocks, flags uint32) error { +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error { fd.LockFD.Init(locks) if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { return err @@ -101,12 +101,12 @@ func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence) } -// Read implmenets vfs.FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts) } -// PRead implmenets vfs.FileDescriptionImpl.PRead. +// PRead implements vfs.FileDescriptionImpl.PRead. func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts) } @@ -127,7 +127,7 @@ func (fd *DynamicBytesFD) Release() {} // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(fs, opts) + return fd.inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. @@ -135,3 +135,13 @@ func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error { // DynamicBytesFiles are immutable. return syserror.EPERM } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. 
+func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 33a5968ca..1d37ccb98 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -19,10 +19,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -57,7 +57,7 @@ type GenericDirectoryFD struct { // NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its // dentry. -func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { +func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { fd := &GenericDirectoryFD{} if err := fd.Init(children, locks, opts); err != nil { return nil, err @@ -71,7 +71,7 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre // Init initializes a GenericDirectoryFD. Use it when overriding // GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the // correct implementation. -func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *lock.FileLocks, opts *vfs.OpenOptions) error { +func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error { if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { // Can't open directories for writing. return syserror.EISDIR @@ -112,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) } -// Release implements vfs.FileDecriptionImpl.Release. +// Release implements vfs.FileDescriptionImpl.Release. func (fd *GenericDirectoryFD) Release() {} func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { @@ -123,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode { return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode } -// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds // o.mu when calling cb. func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { fd.mu.Lock() @@ -132,7 +132,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent opts := vfs.StatOptions{Mask: linux.STATX_INO} // Handle ".". 
if fd.off == 0 { - stat, err := fd.inode().Stat(fd.filesystem(), opts) + stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -152,7 +152,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent if fd.off == 1 { vfsd := fd.vfsfd.VirtualDentry().Dentry() parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode - stat, err := parentInode.Stat(fd.filesystem(), opts) + stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -176,7 +176,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent childIdx := fd.off - 2 for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { inode := it.Dentry.Impl().(*Dentry).inode - stat, err := inode.Stat(fd.filesystem(), opts) + stat, err := inode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -198,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent return err } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() defer fd.mu.Unlock() @@ -226,7 +226,7 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.filesystem() inode := fd.inode() - return inode.Stat(fs, opts) + return inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. @@ -235,3 +235,18 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode return inode.SetStat(ctx, fd.filesystem(), creds, opts) } + +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 8939871c1..61a36cff9 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -684,7 +684,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } - return inode.Stat(fs.VFSFilesystem(), opts) + return inode.Stat(ctx, fs.VFSFilesystem(), opts) } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. 
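Two threads run through the kernfs hunks above: Inode.Stat now takes a context.Context (threaded from StatAt and every IterDirents call site), and GenericDirectoryFD.IterDirents keeps its offset convention, where offsets 0 and 1 are reserved for "." and ".." and children begin at offset 2. The following standalone sketch shows that cursor convention; iterDirents and the emit callback are illustrative names, not the actual kernfs API.

package main

import "fmt"

// iterDirents walks a synthetic directory with the kernfs-style cursor:
// off 0 => ".", off 1 => "..", off >= 2 => children[off-2]. Each entry is
// emitted with the offset of the entry after it (vfs.Dirent.NextOff-style),
// and the final cursor is returned so iteration can resume later.
func iterDirents(children []string, off int64, emit func(name string, nextOff int64) bool) int64 {
	if off == 0 {
		if !emit(".", 1) {
			return 0
		}
		off = 1
	}
	if off == 1 {
		if !emit("..", 2) {
			return 1
		}
		off = 2
	}
	for i := off - 2; i < int64(len(children)); i++ {
		if !emit(children[i], i+3) {
			return i + 2
		}
	}
	return int64(len(children)) + 2
}

func main() {
	next := iterDirents([]string{"cpu", "node0"}, 0, func(name string, nextOff int64) bool {
		fmt.Println(name, nextOff)
		return true
	})
	fmt.Println("resume at", next) // 4
}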
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 0e4927215..579e627f0 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -244,7 +243,7 @@ func (a *InodeAttrs) Mode() linux.FileMode { // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. -func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { +func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK stat.DevMajor = a.devMajor @@ -268,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } @@ -294,6 +293,8 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut // inode numbers are immutable after node creation. // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. + // Also, STATX_SIZE will need some special handling, because read-only static + // files should return EIO for truncate operations. return nil } @@ -470,6 +471,8 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.De if err := o.checkExistingLocked(name, child); err != nil { return err } + + // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. o.removeLocked(name) return nil } @@ -517,6 +520,8 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c if err := o.checkExistingLocked(oldname, child); err != nil { return nil, err } + + // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. replaced := dst.replaceChildLocked(newname, child) return replaced, nil } @@ -557,7 +562,7 @@ type StaticDirectory struct { InodeNoDynamicLookup OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks } var _ Inode = (*StaticDirectory)(nil) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index bbee8ccda..46f207664 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -227,16 +227,19 @@ func (d *Dentry) destroy() { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} +// Although Linux technically supports inotify on pseudo filesystems (inotify +// is implemented at the vfs layer), it is not particularly useful. It is left +// unimplemented until someone actually needs it. +func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. 
-// -// TODO(gvisor.dev/issue/1479): Implement inotify. func (d *Dentry) Watches() *vfs.Watches { return nil } +// OnZeroWatches implements vfs.Dentry.OnZeroWatches. +func (d *Dentry) OnZeroWatches() {} + // InsertChild inserts child into the vfs dentry cache with the given name under // this dentry. This does not update the directory inode, so calling this on // its own isn't sufficient to insert a child into a directory. InsertChild @@ -343,7 +346,7 @@ type inodeMetadata interface { // Stat returns the metadata for this inode. This corresponds to // vfs.FilesystemImpl.StatAt. - Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) + Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) // SetStat updates the metadata for this inode. This corresponds to // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking @@ -425,10 +428,10 @@ type inodeDynamicLookup interface { // IterDirents is used to iterate over dynamically created entries. It invokes // cb on each entry in the directory represented by the FileDescription. // 'offset' is the offset for the entire IterDirents call, which may include - // results from the caller. 'relOffset' is the offset inside the entries - // returned by this IterDirents invocation. In other words, - // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff, - // while 'relOffset' is the place where iteration should start from. + // results from the caller (e.g. "." and ".."). 'relOffset' is the offset + // inside the entries returned by this IterDirents invocation. In other words, + // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as + // the return value, while 'relOffset' is the place to start iteration. IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 6749facf7..dc407eb1d 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -103,7 +102,7 @@ type readonlyDir struct { kernfs.InodeDirectoryNoNewChildren kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks dentry kernfs.Dentry } @@ -133,7 +132,7 @@ type dir struct { kernfs.InodeNoDynamicLookup kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem dentry kernfs.Dentry diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index f9413bbdd..8cf5b35d3 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -29,11 +29,11 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go index 6f47167d3..f5c2462a5 100644 --- a/pkg/sentry/fsimpl/overlay/directory.go +++ b/pkg/sentry/fsimpl/overlay/directory.go @@ -263,3 +263,25 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in return 0, syserror.EINVAL } 
} + +// Sync implements vfs.FileDescriptionImpl.Sync. Forwards sync to the upper +// layer, if there is one. The lower layer doesn't need to sync because it +// never changes. +func (fd *directoryFD) Sync(ctx context.Context) error { + d := fd.dentry() + if !d.isCopiedUp() { + return nil + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + pop := vfs.PathOperation{ + Root: d.upperVD, + Start: d.upperVD, + } + upperFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY}) + if err != nil { + return err + } + err = upperFD.Sync(ctx) + upperFD.DecRef() + return err +} diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index ff82e1f20..6b705e955 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1104,7 +1104,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := rp.Mount() diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index a3c1f7a8d..c0749e711 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -151,7 +151,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := fd.vfsfd.Mount() @@ -176,7 +176,7 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) return nil } -// StatFS implements vfs.FileDesciptionImpl.StatFS. +// StatFS implements vfs.FileDescriptionImpl.StatFS. func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.filesystem().statFS(ctx) } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e660d0e2c..e720d4825 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -35,9 +35,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -415,7 +415,7 @@ type dentry struct { devMinor uint32 ino uint64 - locks lock.FileLocks + locks vfs.FileLocks } // newDentry creates a new dentry. The dentry initially has no references; it @@ -528,6 +528,11 @@ func (d *dentry) Watches() *vfs.Watches { return nil } +// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. 
+func (d *dentry) OnZeroWatches() {} + // iterLayers invokes yield on each layer comprising d, from top to bottom. If // any call to yield returns false, iterLayer stops iteration. func (d *dentry) iterLayers(yield func(vd vfs.VirtualDentry, isUpper bool) bool) { @@ -610,3 +615,13 @@ func (fd *fileDescription) filesystem() *filesystem { func (fd *fileDescription) dentry() *dentry { return fd.vfsfd.Dentry().Impl().(*dentry) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD index c618dbe6c..5950a2d59 100644 --- a/pkg/sentry/fsimpl/pipefs/BUILD +++ b/pkg/sentry/fsimpl/pipefs/BUILD @@ -15,7 +15,6 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index e4dabaa33..811f80a5f 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -82,7 +81,7 @@ type inode struct { kernfs.InodeNotSymlink kernfs.InodeNoopRefCount - locks lock.FileLocks + locks vfs.FileLocks pipe *pipe.VFSPipe ino uint64 @@ -116,7 +115,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode.Stat. 
-func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) return linux.Statx{ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 351ba4ee9..6014138ff 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -22,6 +22,7 @@ go_library( "//pkg/log", "//pkg/refs", "//pkg/safemem", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/inet", @@ -35,7 +36,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/usermem", diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index e2cdb7ee9..79c2725f3 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -38,7 +37,7 @@ type subtasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem task *kernel.Task @@ -129,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac return fd.GenericDirectoryFD.IterDirents(ctx, cb) } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { if fd.task.ExitState() >= kernel.TaskExitZombie { return 0, syserror.ENOENT @@ -166,8 +165,8 @@ func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *v } // Stat implements kernfs.Inode. -func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 44078a765..a5c7aa470 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -39,7 +38,7 @@ type taskInode struct { kernfs.InodeAttrs kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks task *kernel.Task } @@ -157,8 +156,8 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. } // Stat implements kernfs.Inode. 
-func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.Inode.Stat(fs, opts) +func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.Inode.Stat(ctx, fs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index ef6c1d04f..fea29e5f0 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -54,7 +53,7 @@ func taskFDExists(t *kernel.Task, fd int32) bool { } type fdDir struct { - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem task *kernel.Task @@ -65,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDynamicLookup. -func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, absOffset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -73,7 +72,6 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, abs } }) - offset := absOffset + relOffset typ := uint8(linux.DT_REG) if i.produceSymlink { typ = linux.DT_LNK diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index e5eaa91cd..859b7d727 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -30,11 +31,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) +// "There is an (arbitrary) limit on the number of lines in the file. As at +// Linux 3.18, the limit is five lines." - user_namespaces(7) +const maxIDMapLines = 5 + // mm gets the kernel task's MemoryManager. No additional reference is taken on // mm here. This is safe because MemoryManager.destroy is required to leave the // MemoryManager in a state where it's still usable as a DynamicBytesSource. @@ -227,8 +231,9 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Linux will return envp up to and including the first NULL character, // so find it. - if end := bytes.IndexByte(buf.Bytes()[ar.Length():], 0); end != -1 { - buf.Truncate(end) + envStart := int(ar.Length()) + if nullIdx := bytes.IndexByte(buf.Bytes()[envStart:], 0); nullIdx != -1 { + buf.Truncate(envStart + nullIdx) } } @@ -283,7 +288,8 @@ func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -// idMapData implements vfs.DynamicBytesSource for /proc/[pid]/{gid_map|uid_map}. +// idMapData implements vfs.WritableDynamicBytesSource for +// /proc/[pid]/{gid_map|uid_map}. 
// // +stateify savable type idMapData struct { @@ -295,7 +301,7 @@ type idMapData struct { var _ dynamicInode = (*idMapData)(nil) -// Generate implements vfs.DynamicBytesSource.Generate. +// Generate implements vfs.WritableDynamicBytesSource.Generate. func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error { var entries []auth.IDMapEntry if d.gids { @@ -309,6 +315,60 @@ func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + // "In addition, the number of bytes written to the file must be less than + // the system page size, and the write must be performed at the start of + // the file ..." - user_namespaces(7) + srclen := src.NumBytes() + if srclen >= usermem.PageSize || offset != 0 { + return 0, syserror.EINVAL + } + b := make([]byte, srclen) + if _, err := src.CopyIn(ctx, b); err != nil { + return 0, err + } + + // Truncate from the first NULL byte. + var nul int64 + nul = int64(bytes.IndexByte(b, 0)) + if nul == -1 { + nul = srclen + } + b = b[:nul] + // Remove the last \n. + if nul >= 1 && b[nul-1] == '\n' { + b = b[:nul-1] + } + lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1) + if len(lines) > maxIDMapLines { + return 0, syserror.EINVAL + } + + entries := make([]auth.IDMapEntry, len(lines)) + for i, l := range lines { + var e auth.IDMapEntry + _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length) + if err != nil { + return 0, syserror.EINVAL + } + entries[i] = e + } + var err error + if d.gids { + err = d.task.UserNamespace().SetGIDMap(ctx, entries) + } else { + err = d.task.UserNamespace().SetUIDMap(ctx, entries) + } + if err != nil { + return 0, err + } + + // On success, Linux's kernel/user_namespace.c:map_write() always returns + // count, even if fewer bytes were used. + return int64(srclen), nil +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -777,7 +837,7 @@ type namespaceInode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink - locks lock.FileLocks + locks vfs.FileLocks } var _ kernfs.Inode = (*namespaceInode)(nil) @@ -816,7 +876,7 @@ var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil) // Stat implements FileDescriptionImpl. func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(vfs, opts) + return fd.inode.Stat(ctx, vfs, opts) } // SetStat implements FileDescriptionImpl. @@ -830,3 +890,13 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release() { fd.inode.DecRef() } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. 
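The idMapData.Write hunk above follows kernel/user_namespace.c:map_write(): truncate at the first NUL, strip one trailing newline, cap the input at maxIDMapLines (5) lines, and scan three unsigned fields per line. A self-contained sketch of just that parsing step (parseIDMap and mapEntry are illustrative names, not sentry API):

package main

import (
	"bytes"
	"fmt"
)

// mapEntry mirrors the three columns of a uid_map/gid_map line.
type mapEntry struct {
	firstID, firstParentID, length uint32
}

// parseIDMap truncates at the first NUL byte, drops one trailing newline,
// and scans at most five "ID first-of-range length" lines, as described in
// user_namespaces(7).
func parseIDMap(b []byte) ([]mapEntry, error) {
	if i := bytes.IndexByte(b, 0); i != -1 {
		b = b[:i]
	}
	b = bytes.TrimSuffix(b, []byte("\n"))
	lines := bytes.SplitN(b, []byte("\n"), 6)
	if len(lines) > 5 {
		return nil, fmt.Errorf("too many lines")
	}
	entries := make([]mapEntry, 0, len(lines))
	for _, l := range lines {
		var e mapEntry
		if _, err := fmt.Sscan(string(l), &e.firstID, &e.firstParentID, &e.length); err != nil {
			return nil, err
		}
		entries = append(entries, e)
	}
	return entries, nil
}

func main() {
	fmt.Println(parseIDMap([]byte("0 1000 1\n1 100000 65536\n")))
}

Note that, as in the hunk, a successful write reports the full requested length even if the map consumed fewer bytes.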
+func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 58c8b9d05..6d2b90a8b 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -44,7 +43,7 @@ type tasksInode struct { kernfs.OrderedChildren kernfs.AlwaysValid - locks lock.FileLocks + locks vfs.FileLocks fs *filesystem pidns *kernel.PIDNamespace @@ -207,8 +206,8 @@ func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs. return fd.VFSFileDescription(), nil } -func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 237f17def..1b548ccd4 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -15,7 +15,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserror", ], ) @@ -30,6 +29,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index b84463d3a..01ce30a4d 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -101,7 +100,7 @@ type dir struct { kernfs.InodeDirectoryNoNewChildren kernfs.OrderedChildren - locks lock.FileLocks + locks vfs.FileLocks dentry kernfs.Dentry } @@ -139,7 +138,7 @@ type cpuFile struct { // Generate implements vfs.DynamicBytesSource.Generate. 
func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "0-%d", c.maxCores-1) + fmt.Fprintf(buf, "0-%d\n", c.maxCores-1) return nil } diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 4b3602d47..242d5fd12 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -51,7 +51,7 @@ func TestReadCPUFile(t *testing.T) { k := kernel.KernelFromContext(s.Ctx) maxCPUCores := k.ApplicationCores() - expected := fmt.Sprintf("0-%d", maxCPUCores-1) + expected := fmt.Sprintf("0-%d\n", maxCPUCores-1) for _, fname := range []string{"online", "possible", "present"} { pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname)) diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 0e4053a46..400a97996 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -32,6 +32,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index c16a36cdb..e743e8114 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -62,6 +62,7 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } + kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, } @@ -73,7 +74,7 @@ func Boot() (*kernel.Kernel, error) { k.SetMemoryFile(mf) // Pass k as the platform since it is savable, unlike the actual platform. - vdso, err := loader.PrepareVDSO(nil, k) + vdso, err := loader.PrepareVDSO(k) if err != nil { return nil, fmt.Errorf("creating vdso: %v", err) } @@ -103,11 +104,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("initializing kernel: %v", err) } - kernel.VFS2Enabled = true - - if err := k.VFS().Init(); err != nil { - return nil, fmt.Errorf("VFS init: %v", err) - } k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, AllowUserList: true, diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 062321cbc..e73732a6b 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -62,7 +62,6 @@ go_library( "//pkg/sentry/uniqueid", "//pkg/sentry/usage", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sentry/vfs/memxattr", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 913b8a6c5..0a1ad4765 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -79,7 +79,10 @@ func (dir *directory) removeChildLocked(child *dentry) { dir.iterMu.Lock() dir.childList.Remove(child) dir.iterMu.Unlock() - child.unlinked = true +} + +func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error { + return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid))) } type directoryFD struct { @@ -107,13 +110,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fs := fd.filesystem() dir := fd.inode().impl.(*directory) + defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + // fs.mu is required to read d.parent and dentry.name. 
fs.mu.RLock() defer fs.mu.RUnlock() dir.iterMu.Lock() defer dir.iterMu.Unlock() - fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) fd.inode().touchAtime(fd.vfsfd.Mount()) if fd.off == 0 { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 72399b321..ef210a69b 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -79,7 +79,7 @@ afterSymlink: } if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { // Symlink traversal updates access time. - atomic.StoreInt64(&d.inode.atime, d.inode.fs.clock.Now().Nanoseconds()) + child.inode.touchAtime(rp.Mount()) if err := rp.HandleSymlink(symlink.target); err != nil { return nil, err } @@ -182,7 +182,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if dir { ev |= linux.IN_ISDIR } - parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent) + parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) parentDir.inode.touchCMtime() return nil } @@ -237,18 +237,22 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EXDEV } d := vd.Dentry().Impl().(*dentry) - if d.inode.isDir() { + i := d.inode + if i.isDir() { return syserror.EPERM } - if d.inode.nlink == 0 { + if err := vfs.MayLink(auth.CredentialsFromContext(ctx), linux.FileMode(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + return err + } + if i.nlink == 0 { return syserror.ENOENT } - if d.inode.nlink == maxLinks { + if i.nlink == maxLinks { return syserror.EMLINK } - d.inode.incLinksLocked() - d.inode.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent) - parentDir.insertChildLocked(fs.newDentry(d.inode), name) + i.incLinksLocked() + i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */) + parentDir.insertChildLocked(fs.newDentry(i), name) return nil }) } @@ -273,7 +277,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v creds := rp.Credentials() var childInode *inode switch opts.Mode.FileType() { - case 0, linux.S_IFREG: + case linux.S_IFREG: childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) case linux.S_IFIFO: childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) @@ -364,10 +368,13 @@ afterTrailingSymlink: if err != nil { return nil, err } - parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent) + parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) parentDir.inode.touchCMtime() return fd, nil } + if mustCreate { + return nil, syserror.EEXIST + } // Is the file mounted over? if err := rp.CheckMount(&child.vfsd); err != nil { return nil, err @@ -375,7 +382,7 @@ afterTrailingSymlink: // Do we need to resolve a trailing symlink? if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { // Symlink traversal updates access time. 
- atomic.StoreInt64(&child.inode.atime, child.inode.fs.clock.Now().Nanoseconds()) + child.inode.touchAtime(rp.Mount()) if err := rp.HandleSymlink(symlink.target); err != nil { return nil, err } @@ -400,10 +407,10 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open case *regularFile: var fd regularFileFD fd.LockFD.Init(&d.inode.locks) - if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil { return nil, err } - if opts.Flags&linux.O_TRUNC != 0 { + if !afterCreate && opts.Flags&linux.O_TRUNC != 0 { if _, err := impl.truncate(0); err != nil { return nil, err } @@ -416,7 +423,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open } var fd directoryFD fd.LockFD.Init(&d.inode.locks) - if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil { return nil, err } return &fd.vfsfd, nil @@ -485,6 +492,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if !ok { return syserror.ENOENT } + if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil { + return err + } // Note that we don't need to call rp.CheckMount(), since if renamed is a // mount point then we want to rename the mount point, not anything in the // mounted filesystem. @@ -599,6 +609,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if !ok { return syserror.ENOENT } + if err := parentDir.mayDelete(rp.Credentials(), child); err != nil { + return err + } childDir, ok := child.inode.impl.(*directory) if !ok { return syserror.ENOTDIR @@ -618,7 +631,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } parentDir.removeChildLocked(child) - parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent) + parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) // Remove links for child, child/., and child/.. child.inode.decLinksLocked() child.inode.decLinksLocked() @@ -631,14 +644,16 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { + fs.mu.RUnlock() return err } - if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil { + fs.mu.RUnlock() return err } + fs.mu.RUnlock() if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { d.InotifyWithParent(ev, 0, vfs.InodeEvent) @@ -707,6 +722,9 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if !ok { return syserror.ENOENT } + if err := parentDir.mayDelete(rp.Credentials(), child); err != nil { + return err + } if child.inode.isDir() { return syserror.EISDIR } @@ -781,14 +799,16 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. 
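The new mayDelete checks in RenameAt, RmdirAt, and UnlinkAt above delegate to vfs.CheckDeleteSticky, i.e. the classic /tmp rule: in an S_ISVTX directory, removal requires owning either the victim file or the directory itself (capability overrides such as CAP_FOWNER are elided here). A minimal standalone illustration of that predicate:

package main

import "fmt"

const stickyBit = 01000 // S_ISVTX

// canDeleteSticky is a simplified sticky-bit rule: if the parent directory
// has S_ISVTX set, deletion requires owning either the victim file or the
// directory itself.
func canDeleteSticky(dirMode, dirOwner, fileOwner, caller uint32) bool {
	if dirMode&stickyBit == 0 {
		return true
	}
	return caller == fileOwner || caller == dirOwner
}

func main() {
	// World-writable sticky dir (like /tmp, mode 1777): uid 2000 cannot
	// delete uid 1000's file, but uid 1000 can.
	fmt.Println(canDeleteSticky(01777, 0, 1000, 2000)) // false
	fmt.Println(canDeleteSticky(01777, 0, 1000, 1000)) // true
}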
func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { + fs.mu.RUnlock() return err } if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { + fs.mu.RUnlock() return err } + fs.mu.RUnlock() d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil @@ -797,14 +817,16 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { + fs.mu.RUnlock() return err } if err := d.inode.removexattr(rp.Credentials(), name); err != nil { + fs.mu.RUnlock() return err } + fs.mu.RUnlock() d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 77447b32c..abbaa5d60 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -274,11 +274,35 @@ func (fd *regularFileFD) Release() { // noop } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + f := fd.inode().impl.(*regularFile) + + f.inode.mu.Lock() + defer f.inode.mu.Unlock() + oldSize := f.size + size := offset + length + if oldSize >= size { + return nil + } + _, err := f.truncateLocked(size) + return err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } + + // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since + // all state is in-memory. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { + return 0, syserror.EOPNOTSUPP + } + if dst.NumBytes() == 0 { return 0, nil } @@ -301,40 +325,60 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } + + // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since + // all state is in-memory. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + srclen := src.NumBytes() if srclen == 0 { - return 0, nil + return 0, offset, nil } f := fd.inode().impl.(*regularFile) + f.inode.mu.Lock() + defer f.inode.mu.Unlock() + // If the file is opened with O_APPEND, update offset to file size. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking f.inode.mu is sufficient for reading f.size. 
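SetStatAt, SetxattrAt, and RemovexattrAt above trade `defer fs.mu.RUnlock()` for explicit unlocks on every path so that fs.mu is dropped before InotifyWithParent runs: the tmpfs.go hunk further down makes dentry.InotifyWithParent read-lock fs.mu itself, and a sync.RWMutex read lock must not be re-acquired while already held, since a writer queued between the two RLock calls deadlocks the inner one. A toy version of the pattern (store and notify are stand-ins, not sentry types):

package main

import (
	"fmt"
	"sync"
)

type store struct {
	mu     sync.RWMutex
	notify func() // re-acquires mu.RLock internally, like InotifyWithParent
}

// set releases the read lock before calling notify; holding it across the
// call would make the second RLock inside notify deadlock-prone.
func (s *store) set(apply func() error) error {
	s.mu.RLock()
	err := apply()
	s.mu.RUnlock() // drop before notifying.
	if err != nil {
		return err
	}
	s.notify()
	return nil
}

func main() {
	s := &store{}
	s.notify = func() {
		s.mu.RLock()
		defer s.mu.RUnlock()
		fmt.Println("event")
	}
	if err := s.set(func() error { return nil }); err != nil {
		fmt.Println(err)
	}
}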
+ offset = int64(f.size) + } if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - var err error srclen, err = vfs.CheckLimit(ctx, offset, srclen) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(srclen) - f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) - fd.inode().touchCMtimeLocked() - f.inode.mu.Unlock() + f.inode.touchCMtimeLocked() putRegularFileReadWriter(rw) - return n, err + return n, n + offset, err } // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.offMu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.offMu.Unlock() return n, err } @@ -360,11 +404,6 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( return offset, nil } -// Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *regularFileFD) Sync(ctx context.Context) error { - return nil -} - // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { file := fd.inode().impl.(*regularFile) diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go index 64e1c40ad..146c7fdfe 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go @@ -138,48 +138,37 @@ func TestLocks(t *testing.T) { } defer cleanup() - var ( - uid1 lock.UniqueID - uid2 lock.UniqueID - // Non-blocking. - block lock.Blocker - ) - - uid1 = 123 - uid2 = 456 - - if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, block); err != nil { + uid1 := 123 + uid2 := 456 + if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, block); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want) } if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil { t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err) } - if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block); err != nil { + if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil { t.Fatalf("fd.Impl().LockBSD failed: err = %v", err) } - rng1 := lock.LockRange{0, 1} - rng2 := lock.LockRange{1, 2} - - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, rng1, block); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng2, block); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, rng1, block); err != nil { + if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != 
nil { t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err) } - if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng1, block), syserror.ErrWouldBlock; got != want { + if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want { t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want) } - if err := fd.Impl().UnlockPOSIX(ctx, uid1, rng1); err != nil { + if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil { t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err) } } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 71a7522af..2545d88e9 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -36,11 +36,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -215,11 +215,6 @@ type dentry struct { // filesystem.mu. name string - // unlinked indicates whether this dentry has been unlinked from its parent. - // It is only set to true on an unlink operation, and never set from true to - // false. unlinked is protected by filesystem.mu. - unlinked bool - // dentryEntry (ugh) links dentries into their parent directory.childList. dentryEntry @@ -259,18 +254,22 @@ func (d *dentry) DecRef() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { if d.inode.isDir() { events |= linux.IN_ISDIR } + // tmpfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates + // that d was deleted. + deleted := d.vfsd.IsDead() + + d.inode.fs.mu.RLock() // The ordering below is important, Linux always notifies the parent first. if d.parent != nil { - // Note that d.parent or d.name may be stale if there is a concurrent - // rename operation. Inotify does not provide consistency guarantees. - d.parent.inode.watches.NotifyWithExclusions(d.name, events, cookie, et, d.unlinked) + d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted) } - d.inode.watches.Notify("", events, cookie, et) + d.inode.watches.Notify("", events, cookie, et, deleted) + d.inode.fs.mu.RUnlock() } // Watches implements vfs.DentryImpl.Watches. @@ -278,6 +277,9 @@ func (d *dentry) Watches() *vfs.Watches { return &d.inode.watches } +// OnZeroWatches implements vfs.Dentry.OnZeroWatches. +func (d *dentry) OnZeroWatches() {} + // inode represents a filesystem object. type inode struct { // fs is the owning filesystem. fs is immutable. @@ -310,7 +312,7 @@ type inode struct { ctime int64 // nanoseconds mtime int64 // nanoseconds - locks lock.FileLocks + locks vfs.FileLocks // Inotify watches for this inode. 
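The pwrite/Write split above exists because O_APPEND rebases each write to the current file size, so Write must adopt pwrite's final offset rather than computing fd.off+n. A toy model of that bookkeeping (the file type here is hypothetical):

package main

import "fmt"

type file struct {
	data       []byte
	appendMode bool
	off        int64
}

// pwrite returns bytes written and the final offset; with appendMode set,
// the requested offset is ignored and the write lands at end-of-file.
func (f *file) pwrite(p []byte, off int64) (n, final int64) {
	if f.appendMode {
		off = int64(len(f.data))
	}
	end := off + int64(len(p))
	for int64(len(f.data)) < end {
		f.data = append(f.data, 0) // zero-fill any hole
	}
	copy(f.data[off:end], p)
	return int64(len(p)), end
}

func (f *file) write(p []byte) int64 {
	n, final := f.pwrite(p, f.off)
	f.off = final // not f.off+n: O_APPEND may have moved the base offset.
	return n
}

func main() {
	f := &file{data: []byte("hello"), appendMode: true}
	f.write([]byte("!"))
	fmt.Println(string(f.data), f.off) // "hello!" 6
}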
watches vfs.Watches @@ -336,7 +338,6 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth i.ctime = now i.mtime = now // i.nlink initialized by caller - i.watches = vfs.Watches{} i.impl = impl } @@ -451,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) { } } -func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error { +func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -459,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&i.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { return err } i.mu.Lock() @@ -694,7 +696,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) d := fd.dentry() - if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, creds, &opts); err != nil { return err } @@ -761,9 +763,26 @@ func NewMemfd(mount *vfs.Mount, creds *auth.Credentials, allowSeals bool, name s // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with // FMODE_READ | FMODE_WRITE. var fd regularFileFD + fd.Init(&inode.locks) flags := uint32(linux.O_RDWR) if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} + +// Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all +// filesystem state is in-memory. +func (*fileDescription) Sync(context.Context) error { + return nil +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a28eab8b8..f6886a758 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -85,6 +85,7 @@ go_library( name = "kernel", srcs = [ "abstract_socket_namespace.go", + "aio.go", "context.go", "fd_table.go", "fd_table_unsafe.go", @@ -131,6 +132,7 @@ go_library( "task_stop.go", "task_syscall.go", "task_usermem.go", + "task_work.go", "thread_group.go", "threads.go", "timekeeper.go", @@ -199,6 +201,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/state", "//pkg/state/statefile", + "//pkg/state/wire", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go new file mode 100644 index 000000000..0ac78c0b8 --- /dev/null +++ b/pkg/sentry/kernel/aio.go @@ -0,0 +1,81 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "time" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" +) + +// AIOCallback is an function that does asynchronous I/O on behalf of a task. +type AIOCallback func(context.Context) + +// QueueAIO queues an AIOCallback which will be run asynchronously. +func (t *Task) QueueAIO(cb AIOCallback) { + ctx := taskAsyncContext{t: t} + wg := &t.TaskSet().aioGoroutines + wg.Add(1) + go func() { + cb(ctx) + wg.Done() + }() +} + +type taskAsyncContext struct { + context.NoopSleeper + t *Task +} + +// Debugf implements log.Logger.Debugf. +func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { + ctx.t.Debugf(format, v...) +} + +// Infof implements log.Logger.Infof. +func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { + ctx.t.Infof(format, v...) +} + +// Warningf implements log.Logger.Warningf. +func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { + ctx.t.Warningf(format, v...) +} + +// IsLogging implements log.Logger.IsLogging. +func (ctx taskAsyncContext) IsLogging(level log.Level) bool { + return ctx.t.IsLogging(level) +} + +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return ctx.t.Deadline() +} + +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return ctx.t.Done() +} + +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return ctx.t.Err() +} + +// Value implements context.Context.Value. +func (ctx taskAsyncContext) Value(key interface{}) interface{} { + return ctx.t.Value(key) +} diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index 0c40bf315..dd5f0f5fa 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -18,7 +18,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" ) // contextID is the kernel package's type for context.Context.Value keys. @@ -113,55 +112,3 @@ func (*Task) Done() <-chan struct{} { func (*Task) Err() error { return nil } - -// AsyncContext returns a context.Context that may be used by goroutines that -// do work on behalf of t and therefore share its contextual values, but are -// not t's task goroutine (e.g. asynchronous I/O). -func (t *Task) AsyncContext() context.Context { - return taskAsyncContext{t: t} -} - -type taskAsyncContext struct { - context.NoopSleeper - t *Task -} - -// Debugf implements log.Logger.Debugf. -func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { - ctx.t.Debugf(format, v...) -} - -// Infof implements log.Logger.Infof. -func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { - ctx.t.Infof(format, v...) -} - -// Warningf implements log.Logger.Warningf. -func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { - ctx.t.Warningf(format, v...) -} - -// IsLogging implements log.Logger.IsLogging. 
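QueueAIO above pairs each callback goroutine with TaskSet.aioGoroutines (added in threads.go further down) so that Kernel.Pause can wait for in-flight asynchronous I/O as well as running tasks. The underlying pattern is a WaitGroup incremented before the goroutine starts; roughly:

package main

import (
	"fmt"
	"sync"
)

type set struct {
	wg sync.WaitGroup // counts outstanding async callbacks
}

// queue runs cb asynchronously. Add must happen before the goroutine is
// spawned so a concurrent pause cannot observe a zero count while cb is
// still pending.
func (s *set) queue(cb func()) {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		cb()
	}()
}

// pause blocks until all queued callbacks have finished.
func (s *set) pause() { s.wg.Wait() }

func main() {
	var s set
	for i := 0; i < 3; i++ {
		i := i
		s.queue(func() { fmt.Println("aio", i) })
	}
	s.pause()
}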
-func (ctx taskAsyncContext) IsLogging(level log.Level) bool { - return ctx.t.IsLogging(level) -} - -// Deadline implements context.Context.Deadline. -func (ctx taskAsyncContext) Deadline() (time.Time, bool) { - return ctx.t.Deadline() -} - -// Done implements context.Context.Done. -func (ctx taskAsyncContext) Done() <-chan struct{} { - return ctx.t.Done() -} - -// Err implements context.Context.Err. -func (ctx taskAsyncContext) Err() error { - return ctx.t.Err() -} - -// Value implements context.Context.Value. -func (ctx taskAsyncContext) Value(key interface{}) interface{} { - return ctx.t.Value(key) -} diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 3d78cd48f..4c0f1e41f 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -107,7 +107,7 @@ type EventPoll struct { // different lock to avoid circular lock acquisition order involving // the wait queue mutexes and mu. The full order is mu, observed file // wait queue mutex, then listsMu; this allows listsMu to be acquired - // when readyCallback is called. + // when (*pollEntry).Callback is called. // // An entry is always in one of the following lists: // readyList -- when there's a chance that it's ready to have @@ -116,7 +116,7 @@ type EventPoll struct { // readEvents() functions always call the entry's file // Readiness() function to confirm it's ready. // waitingList -- when there's no chance that the entry is ready, - // so it's waiting for the readyCallback to be called + // so it's waiting for the (*pollEntry).Callback to be called // on it before it gets moved to the readyList. // disabledList -- when the entry is disabled. This happens when // a one-shot entry gets delivered via readEvents(). @@ -269,21 +269,19 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent { return ret } -// readyCallback is called when one of the files we're polling becomes ready. It -// moves said file to the readyList if it's currently in the waiting list. -type readyCallback struct{} - // Callback implements waiter.EntryCallback.Callback. -func (*readyCallback) Callback(w *waiter.Entry) { - entry := w.Context.(*pollEntry) - e := entry.epoll +// +// Callback is called when one of the files we're polling becomes ready. It +// moves said file to the readyList if it's currently in the waiting list. +func (p *pollEntry) Callback(*waiter.Entry) { + e := p.epoll e.listsMu.Lock() - if entry.curList == &e.waitingList { - e.waitingList.Remove(entry) - e.readyList.PushBack(entry) - entry.curList = &e.readyList + if p.curList == &e.waitingList { + e.waitingList.Remove(p) + e.readyList.PushBack(p) + p.curList = &e.readyList e.listsMu.Unlock() e.Notify(waiter.EventIn) @@ -310,7 +308,7 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) { // Check if the file happens to already be in a ready state. 
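The epoll change above folds the stateless readyCallback type into a Callback method on pollEntry itself, so the entry registers directly as the waiter callback instead of round-tripping through waiter.Entry.Context. In miniature (all names here are illustrative):

package main

import "fmt"

type entryCallback interface {
	callback()
}

type waitEntry struct {
	cb entryCallback
}

type pollEntry struct {
	waiter  waitEntry
	curList string
}

// callback moves the entry from the waiting list to the ready list, which
// is what (*pollEntry).Callback does under listsMu in the real code.
func (p *pollEntry) callback() {
	if p.curList == "waiting" {
		p.curList = "ready"
	}
}

func main() {
	p := &pollEntry{curList: "waiting"}
	p.waiter.cb = p // self-registration replaces the Context round-trip
	p.waiter.cb.callback()
	fmt.Println(p.curList) // ready
}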
ready := f.Readiness(entry.mask) & entry.mask if ready != 0 { - (*readyCallback).Callback(nil, &entry.waiter) + entry.Callback(&entry.waiter) } } @@ -380,10 +378,9 @@ func (e *EventPoll) AddEntry(id FileIdentifier, flags EntryFlags, mask waiter.Ev userData: data, epoll: e, flags: flags, - waiter: waiter.Entry{Callback: &readyCallback{}}, mask: mask, } - entry.waiter.Context = entry + entry.waiter.Callback = entry e.files[id] = entry entry.file = refs.NewWeakRef(id.File, entry) @@ -406,7 +403,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter } // Unregister the old mask and remove entry from the list it's in, so - // readyCallback is guaranteed to not be called on this entry anymore. + // (*pollEntry).Callback is guaranteed to not be called on this entry anymore. entry.id.File.EventUnregister(&entry.waiter) // Remove entry from whatever list it's in. This ensure that no other diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go index 8e9f200d0..7c61e0258 100644 --- a/pkg/sentry/kernel/epoll/epoll_state.go +++ b/pkg/sentry/kernel/epoll/epoll_state.go @@ -21,8 +21,7 @@ import ( // afterLoad is invoked by stateify. func (p *pollEntry) afterLoad() { - p.waiter = waiter.Entry{Callback: &readyCallback{}} - p.waiter.Context = p + p.waiter.Callback = p p.file = refs.NewWeakRef(p.id.File, p) p.id.File.EventRegister(&p.waiter, p.mask) } diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index b9126e946..2b3955598 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -11,6 +11,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", "//pkg/sync", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index d32c3e90a..153d2cd9b 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -20,15 +20,21 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) -// New creates a new FileAsync. +// New creates a new fs.FileAsync. func New() fs.FileAsync { return &FileAsync{} } +// NewVFS2 creates a new vfs.FileAsync. +func NewVFS2() vfs.FileAsync { + return &FileAsync{} +} + // FileAsync sends signals when the registered file is ready for IO. // // +stateify savable @@ -170,3 +176,13 @@ func (a *FileAsync) SetOwnerProcessGroup(requester *kernel.Task, recipient *kern a.recipientTG = nil a.recipientPG = recipient } + +// ClearOwner unsets the current signal recipient. +func (a *FileAsync) ClearOwner() { + a.mu.Lock() + defer a.mu.Unlock() + a.requester = nil + a.recipientT = nil + a.recipientTG = nil + a.recipientPG = nil +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index b35afafe3..4b7d234a4 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" ) // FDFlags define flags for an individual descriptor. @@ -148,7 +149,12 @@ func (f *FDTable) drop(file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(file *vfs.FileDescription) { - // TODO(gvisor.dev/issue/1480): Release locks. + // Release any POSIX lock possibly held by the FDTable. 
Range {0, 0} means the + // entire file. + err := file.UnlockPOSIX(context.Background(), f, 0, 0, linux.SEEK_SET) + if err != nil && err != syserror.ENOLCK { + panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) + } // Generate inotify events. ev := uint32(linux.IN_CLOSE_NOWRITE) @@ -157,7 +163,7 @@ func (f *FDTable) dropVFS2(file *vfs.FileDescription) { } file.Dentry().InotifyWithParent(ev, 0, vfs.PathEvent) - // Drop the table reference. + // Drop the table's reference. file.DecRef() } @@ -458,6 +464,29 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { return nil } +// SetFlagsVFS2 sets the flags for the given file descriptor. +// +// True is returned iff flags were changed. +func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { + if fd < 0 { + // Don't accept negative FDs. + return syscall.EBADF + } + + f.mu.Lock() + defer f.mu.Unlock() + + file, _, _ := f.getVFS2(fd) + if file == nil { + // No file found. + return syscall.EBADF + } + + // Update the flags. + f.setVFS2(fd, file, flags) + return nil +} + // Get returns a reference to the file and the flags for the FD or nil if no // file is defined for the given fd. // diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 732e66da4..bcc1b29a8 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3 } } -// UnlockPI unlock the futex following the Priority-inheritance futex -// rules. The address provided must contain the caller's TID. If there are -// waiters, TID of the next waiter (FIFO) is set to the given address, and the -// waiter woken up. If there are no waiters, 0 is set to the address. +// UnlockPI unlocks the futex following the Priority-inheritance futex rules. +// The address provided must contain the caller's TID. If there are waiters, +// TID of the next waiter (FIFO) is set to the given address, and the waiter +// woken up. If there are no waiters, 0 is set to the address. func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error { k, err := getKey(t, addr, private) if err != nil { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index bcbeb6a39..15dae0f5b 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -34,7 +34,6 @@ package kernel import ( "errors" "fmt" - "io" "path/filepath" "sync/atomic" "time" @@ -73,6 +72,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -81,6 +81,10 @@ import ( // easy access everywhere. To be removed once VFS2 becomes the default. var VFS2Enabled = false +// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow +// easy access everywhere. To be removed once FUSE is completed. +var FUSEEnabled = false + // Kernel represents an emulated Linux kernel. It must be initialized by calling // Init() or LoadFrom(). // @@ -417,7 +421,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. 
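dropVFS2 above unconditionally releases POSIX locks because, per POSIX, closing any descriptor for a file drops all of the owner's record locks on it; a start of 0 with length 0 from SEEK_SET denotes the whole file. A sketch of that fcntl-style range convention (lockRange is a made-up helper, not sentry API):

package main

import "fmt"

// lockRange models fcntl-style l_start/l_len semantics relative to a whence
// base: a zero length means "from start to end of file", so {0, 0} from
// SEEK_SET covers the whole file.
func lockRange(base, start, length uint64) (first, last uint64) {
	first = base + start
	if length == 0 {
		return first, ^uint64(0) // lock extends to infinity
	}
	return first, first + length - 1
}

func main() {
	first, last := lockRange(0, 0, 0)
	fmt.Printf("whole file: [%d, %d]\n", first, last)
}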
-func (k *Kernel) SaveTo(w io.Writer) error { +func (k *Kernel) SaveTo(w wire.Writer) error { saveStart := time.Now() ctx := k.SupervisorContext() @@ -452,9 +456,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { return err } - // Ensure that all pending asynchronous work is complete: - // - inode and mount release - // - asynchronuous IO + // Ensure that all inode and mount release operations have completed. fs.AsyncBarrier() // Once all fs work has completed (flushed references have all been released), @@ -475,18 +477,18 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { + if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - var stats state.Stats - if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { + stats, err := state.Save(k.SupervisorContext(), w, k) + if err != nil { return err } - log.Infof("Kernel save stats: %s", &stats) + log.Infof("Kernel save stats: %s", stats.String()) log.Infof("Kernel save took [%s].", time.Since(kernelStart)) // Save the memory file's state. @@ -631,7 +633,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -642,7 +644,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { + if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -657,11 +659,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() - var stats state.Stats - if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { + stats, err := state.Load(k.SupervisorContext(), r, k) + if err != nil { return err } - log.Infof("Kernel load stats: %s", &stats) + log.Infof("Kernel load stats: %s", stats.String()) log.Infof("Kernel load took [%s].", time.Since(kernelStart)) // rootNetworkNamespace should be populated after loading the state file. @@ -892,7 +894,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if mntnsVFS2 == nil { // MountNamespaceVFS2 adds a reference to the namespace, which is // transferred to the new process. - mntnsVFS2 = k.GlobalInit().Leader().MountNamespaceVFS2() + mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2() } // Get the root directory from the MountNamespace. root := args.MountNamespaceVFS2.Root() @@ -1249,13 +1251,15 @@ func (k *Kernel) Kill(es ExitStatus) { } // Pause requests that all tasks in k temporarily stop executing, and blocks -// until all tasks in k have stopped. Multiple calls to Pause nest and require -// an equal number of calls to Unpause to resume execution. +// until all tasks and asynchronous I/O operations in k have stopped. 
Multiple +// calls to Pause nest and require an equal number of calls to Unpause to +// resume execution. func (k *Kernel) Pause() { k.extMu.Lock() k.tasks.BeginExternalStop() k.extMu.Unlock() k.tasks.runningGoroutines.Wait() + k.tasks.aioGoroutines.Wait() } // Unpause ends the effect of a previous call to Pause. If Unpause is called @@ -1465,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 { return now } +// AfterFunc implements tcpip.Clock.AfterFunc. +func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(k.realtimeClock, d, f) +} + // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 0db546b98..449643118 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -26,8 +26,8 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index c0e9ee1f4..45d4c5fc1 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -20,8 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -63,12 +63,12 @@ func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { // Preconditions: statusFlags should not contain an open access mode. func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { // Connected pipes share the same locks. - locks := &lock.FileLocks{} + locks := &vfs.FileLocks{} return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } // Open opens the pipe represented by vp. -func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) (*vfs.FileDescription, error) { +func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() defer vp.mu.Unlock() @@ -130,7 +130,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s } // Preconditions: vp.mu must be held. -func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *lock.FileLocks) *vfs.FileDescription { +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription { fd := &VFSPipeFD{ pipe: &vp.pipe, } @@ -200,6 +200,11 @@ func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask { } } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *VFSPipeFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ESPIPE +} + // EventRegister implements waiter.Waitable.EventRegister. 
func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { fd.pipe.EventRegister(e, mask) @@ -451,3 +456,13 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr } return n, err } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index bfd779837..c211fc8d0 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -20,7 +20,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index f66cfcc7f..55b4c2cdb 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -45,7 +45,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -370,7 +369,7 @@ type Shm struct { // fr is the offset into mfp.MemoryFile() that backs this contents of this // segment. Immutable. - fr platform.FileRange + fr memmap.FileRange // mu protects all fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 4607cde2f..a83ce219c 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -98,6 +98,15 @@ func (s *syslog) Log() []byte { s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...) } + if VFS2Enabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...) + if FUSEEnabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...) + } + } + time += rand.Float64() / 2 s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index f48247c94..c4db05bd8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -68,6 +68,21 @@ type Task struct { // runState is exclusive to the task goroutine. runState taskRunState + // taskWorkCount represents the current size of the task work queue. It is + // used to avoid acquiring taskWorkMu when the queue is empty. + // + // Must accessed with atomic memory operations. + taskWorkCount int32 + + // taskWorkMu protects taskWork. + taskWorkMu sync.Mutex `state:"nosave"` + + // taskWork is a queue of work to be executed before resuming user execution. + // It is similar to the task_work mechanism in Linux. + // + // taskWork is exclusive to the task goroutine. + taskWork []TaskWorker + // haveSyscallReturn is true if tc.Arch().Return() represents a value // returned by a syscall (or set by ptrace after a syscall). 
// @@ -550,6 +565,10 @@ type Task struct { // futexWaiter is exclusive to the task goroutine. futexWaiter *futex.Waiter `state:"nosave"` + // robustList is a pointer to the head of the tasks's robust futex + // list. + robustList usermem.Addr + // startTime is the real time at which the task started. It is set when // a Task is created or invokes execve(2). // diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 9b69f3cbe..7803b98d0 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -207,6 +207,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { return flags.CloseOnExec }) + // Handle the robust futex list. + t.exitRobustList() + // NOTE(b/30815691): We currently do not implement privileged // executables (set-user/group-ID bits and file capabilities). This // allows us to unconditionally enable user dumpability on the new mm. diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c4ade6e8e..231ac548a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState { } } + // Handle the robust futex list. + t.exitRobustList() + // Deactivate the address space and update max RSS before releasing the // task's MM. t.Deactivate() diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index a53e77c9f..4b535c949 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -15,6 +15,7 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" ) @@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) { func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) { return t.MemoryManager().GetSharedFutexKey(t, addr) } + +// GetRobustList sets the robust futex list for the task. +func (t *Task) GetRobustList() usermem.Addr { + t.mu.Lock() + addr := t.robustList + t.mu.Unlock() + return addr +} + +// SetRobustList sets the robust futex list for the task. +func (t *Task) SetRobustList(addr usermem.Addr) { + t.mu.Lock() + t.robustList = addr + t.mu.Unlock() +} + +// exitRobustList walks the robust futex list, marking locks dead and notifying +// wakers. It corresponds to Linux's exit_robust_list(). Following Linux, +// errors are silently ignored. +func (t *Task) exitRobustList() { + t.mu.Lock() + addr := t.robustList + t.robustList = 0 + t.mu.Unlock() + + if addr == 0 { + return + } + + var rl linux.RobustListHead + if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil { + return + } + + next := rl.List + done := 0 + var pendingLockAddr usermem.Addr + if rl.ListOpPending != 0 { + pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset) + } + + // Wake up normal elements. + for usermem.Addr(next) != addr { + // We traverse to the next element of the list before we + // actually wake anything. This prevents the race where waking + // this futex causes a modification of the list. + thisLockAddr := usermem.Addr(next + rl.FutexOffset) + + // Try to decode the next element in the list before waking the + // current futex. But don't check the error until after we've + // woken the current futex. Linux does it in this order too + _, nextErr := t.CopyIn(usermem.Addr(next), &next) + + // Wakeup the current futex if it's not pending. 
+ if thisLockAddr != pendingLockAddr { + t.wakeRobustListOne(thisLockAddr) + } + + // If there was an error copying the next futex, we must bail. + if nextErr != nil { + break + } + + // This is a user structure, so it could be a massive list, or + // even contain a loop if they are trying to mess with us. We + // cap traversal to prevent that. + done++ + if done >= linux.ROBUST_LIST_LIMIT { + break + } + } + + // Is there a pending entry to wake? + if pendingLockAddr != 0 { + t.wakeRobustListOne(pendingLockAddr) + } +} + +// wakeRobustListOne wakes a single futex from the robust list. +func (t *Task) wakeRobustListOne(addr usermem.Addr) { + // Bit 0 in address signals PI futex. + pi := addr&1 == 1 + addr = addr &^ 1 + + // Load the futex. + f, err := t.LoadUint32(addr) + if err != nil { + // Can't read this single value? Ignore the problem. + // We can wake the other futexes in the list. + return + } + + tid := uint32(t.ThreadID()) + for { + // Is this held by someone else? + if f&linux.FUTEX_TID_MASK != tid { + return + } + + // This thread is dying and it's holding this futex. We need to + // set the owner died bit and wake up any waiters. + newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED + if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil { + return + } else if curF != f { + // Futex changed out from under us. Try again... + f = curF + continue + } + + // Wake waiters if there are any. + if f&linux.FUTEX_WAITERS != 0 { + private := f&linux.FUTEX_PRIVATE_FLAG != 0 + if pi { + t.Futex().UnlockPI(t, addr, tid, private) + return + } + t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1) + } + + // Done. + return + } +} diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index d654dd997..7d4f44caf 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState { return (*runInterrupt)(nil) } - // We're about to switch to the application again. If there's still a + // Execute any task work callbacks before returning to user space. + if atomic.LoadInt32(&t.taskWorkCount) > 0 { + t.taskWorkMu.Lock() + queue := t.taskWork + t.taskWork = nil + atomic.StoreInt32(&t.taskWorkCount, 0) + t.taskWorkMu.Unlock() + + // Do not hold taskWorkMu while executing task work, which may register + // more work. + for _, work := range queue { + work.TaskWork(t) + } + } + + // We're about to switch to the application again. If there's still an // unhandled SyscallRestartErrno that wasn't translated to an EINTR, // restart the syscall that was interrupted. If there's a saved signal // mask, restore it. (Note that restoring the saved signal mask may unblock diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go new file mode 100644 index 000000000..dda5a433a --- /dev/null +++ b/pkg/sentry/kernel/task_work.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
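exitRobustList and wakeRobustListOne above mirror Linux's exit_robust_list(): each futex word packs the owner TID with the FUTEX_WAITERS and FUTEX_OWNER_DIED bits, and the exiting task CAS-loops to publish OWNER_DIED while preserving WAITERS. A standalone model of that word update:

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	futexTIDMask   = 0x3fffffff // FUTEX_TID_MASK
	futexWaiters   = 0x80000000 // FUTEX_WAITERS
	futexOwnerDied = 0x40000000 // FUTEX_OWNER_DIED
)

// markOwnerDied mimics the CAS loop in wakeRobustListOne: if the word is
// owned by tid, replace it with OWNER_DIED while preserving WAITERS, and
// report whether any waiters need waking.
func markOwnerDied(word *uint32, tid uint32) (wake bool) {
	for {
		f := atomic.LoadUint32(word)
		if f&futexTIDMask != tid {
			return false // held by someone else (or free)
		}
		newF := (f & futexWaiters) | futexOwnerDied
		if atomic.CompareAndSwapUint32(word, f, newF) {
			return f&futexWaiters != 0
		}
		// Lost a race; reload and retry.
	}
}

func main() {
	w := uint32(1234 | futexWaiters) // owned by TID 1234, with waiters
	fmt.Println(markOwnerDied(&w, 1234), fmt.Sprintf("%#x", w)) // true 0xc0000000
}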
+ +package kernel + +import "sync/atomic" + +// TaskWorker is a deferred task. +// +// This must be savable. +type TaskWorker interface { + // TaskWork will be executed prior to returning to user space. Note that + // TaskWork may call RegisterWork again, but this will not be executed until + // the next return to user space, unlike in Linux. This effectively allows + // registration of indefinite user return hooks, but not by default. + TaskWork(t *Task) +} + +// RegisterWork can be used to register additional task work that will be +// performed prior to returning to user space. See TaskWorker.TaskWork for +// semantics regarding registration. +func (t *Task) RegisterWork(work TaskWorker) { + t.taskWorkMu.Lock() + defer t.taskWorkMu.Unlock() + atomic.AddInt32(&t.taskWorkCount, 1) + t.taskWork = append(t.taskWork, work) +} diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 52849f5b3..4dfd2c990 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -366,7 +366,8 @@ func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error { // terminal is stolen, and all processes that had it as controlling // terminal lose it." - tty_ioctl(4) if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session { - if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 { + // Stealing requires CAP_SYS_ADMIN in the root user namespace. + if creds := auth.CredentialsFromContext(tg.leader); !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) || arg != 1 { return syserror.EPERM } // Steal the TTY away. Unlike TIOCNOTTY, don't send signals. diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index bf2dabb6e..872e1a82d 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -87,6 +87,13 @@ type TaskSet struct { // at time of save (but note that this is not necessarily the same thing as // sync.WaitGroup's zero value). runningGoroutines sync.WaitGroup `state:"nosave"` + + // aioGoroutines is the number of goroutines running async I/O + // callbacks. + // + // aioGoroutines is not saved but is required to be zero at the time of + // save. + aioGoroutines sync.WaitGroup `state:"nosave"` } // newTaskSet returns a new, empty TaskSet. diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 7ba7dc50c..2817aa3ba 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -6,6 +6,7 @@ go_library( name = "time", srcs = [ "context.go", + "tcpip.go", "time.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go new file mode 100644 index 000000000..c4474c0cf --- /dev/null +++ b/pkg/sentry/kernel/time/tcpip.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
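The TaskWorker interface above is intentionally small: a client supplies a TaskWork method and hands the value to RegisterWork. A minimal sketch of such a client, assuming it lives inside package kernel (runOnce is hypothetical, not part of this change, and a real implementation would need to be savable, which a bare func field is not):

    // runOnce runs a closure on the task goroutine just before the task
    // next returns to user space.
    type runOnce struct {
    	fn func(t *Task)
    }

    // TaskWork implements TaskWorker.TaskWork.
    func (r *runOnce) TaskWork(t *Task) {
    	r.fn(t) // executed from runApp, not the registering goroutine
    }

    // Registration, e.g. from a syscall handler:
    //
    //	t.RegisterWork(&runOnce{fn: func(t *Task) {
    //		// Inspect or update task state here.
    //	}})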
+
+package time
+
+import (
+ "sync"
+ "time"
+)
+
+// TcpipAfterFunc waits for duration to elapse according to clock, then runs fn.
+// The timer is started immediately and will fire exactly once.
+func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer {
+ timer := &TcpipTimer{
+ clock: clock,
+ }
+ timer.notifier = functionNotifier{
+ fn: func() {
+ // tcpip.Timer.Stop() explicitly states that the function is called in a
+ // separate goroutine that Stop() does not synchronize with.
+ // Timer.Destroy() synchronizes with calls to TimerListener.Notify().
+ // This is semantically meaningful because, in the former case, it's
+ // legal to call tcpip.Timer.Stop() while holding locks that may also be
+ // taken by the function, but this isn't so in the latter case. Most
+ // immediately, Timer calls TimerListener.Notify() while holding
+ // Timer.mu. A deadlock occurs without spawning a goroutine:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify()
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock!
+ //
+ // Spawning a goroutine avoids the deadlock:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify() <- Launches T2
+ // T2:
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, blocks
+ // T1:
+ // => (returns) <- Timer.mu.Unlock() called
+ // T2:
+ // => (continues) <- No deadlock!
+ go func() {
+ timer.Stop()
+ fn()
+ }()
+ },
+ }
+ timer.Reset(duration)
+ return timer
+}
+
+// TcpipTimer is a resettable timer with variable duration expirations.
+// Implements tcpip.Timer, which does not define a Destroy method; instead, all
+// resources are released after timer expiration and calls to Timer.Stop.
+//
+// Must be created by TcpipAfterFunc.
+type TcpipTimer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // notifier is called when the Timer expires. notifier is immutable.
+ notifier functionNotifier
+
+ // mu protects t.
+ mu sync.Mutex
+
+ // t stores the latest running Timer. This is replaced whenever Reset is
+ // called since Timer cannot be restarted once it has been Destroyed by Stop.
+ //
+ // This field is nil iff Stop has been called.
+ t *Timer
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (r *TcpipTimer) Stop() bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ return false
+ }
+ _, lastSetting := r.t.Swap(Setting{})
+ r.t.Destroy()
+ r.t = nil
+ return lastSetting.Enabled
+}
+
+// Reset implements tcpip.Timer.Reset.
+func (r *TcpipTimer) Reset(d time.Duration) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ r.t = NewTimer(r.clock, &r.notifier)
+ }
+
+ r.t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: r.clock.Now().Add(d),
+ })
+}
+
+// functionNotifier is a TimerListener that runs a function.
+//
+// functionNotifier cannot be saved or loaded.
+type functionNotifier struct {
+ fn func()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) {
+ f.fn()
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (f *functionNotifier) Destroy() {} diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 0adf25691..7c4fefb16 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/log" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -90,7 +90,7 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. -func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) { +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { return &Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), }, nil @@ -210,9 +210,6 @@ func (t *Timekeeper) startUpdater() { p.realtimeBaseRef = int64(realtimeParams.BaseRef) p.realtimeFrequency = realtimeParams.Frequency } - - log.Debugf("Updating VDSO parameters: %+v", p) - return p }); err != nil { log.Warningf("Unable to update VDSO parameter page: %v", err) diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index f1b3c212c..290c32466 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -58,7 +58,7 @@ type vdsoParams struct { type VDSOParamPage struct { // The parameter page is fr, allocated from mfp.MemoryFile(). mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange // seq is the current sequence count written to the page. // @@ -81,7 +81,7 @@ type VDSOParamPage struct { // * VDSOParamPage must be the only writer to fr. // // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. -func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage { +func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { return &VDSOParamPage{mfp: mfp, fr: fr} } diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index c6aa65f28..34bdb0b69 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -30,9 +30,6 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", "//pkg/sentry/fsbridge", "//pkg/sentry/kernel/auth", "//pkg/sentry/limits", @@ -45,6 +42,5 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/usermem", - "//pkg/waiter", ], ) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 616fafa2c..ddeaff3db 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -90,14 +90,23 @@ type elfInfo struct { sharedObject bool } +// fullReader interface extracts the ReadFull method from fsbridge.File so that +// client code does not need to define an entire fsbridge.File when only read +// functionality is needed. +// +// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap +// vfs.FileDescription's PRead/Read instead. +type fullReader interface { + // ReadFull is the same as fsbridge.File.ReadFull. 
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) +} + // parseHeader parse the ELF header, verifying that this is a supported ELF // file and returning the ELF program headers. // // This is similar to elf.NewFile, except that it is more strict about what it // accepts from the ELF, and it doesn't parse unnecessary parts of the file. -// -// ctx may be nil if f does not need it. -func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) { +func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // Check ident first; it will tell us the endianness of the rest of the // structs. var ident [elf.EI_NIDENT]byte diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 88449fe95..986c7fb4d 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -80,22 +79,6 @@ type LoadArgs struct { Features *cpuid.FeatureSet } -// readFull behaves like io.ReadFull for an *fs.File. -func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var total int64 - for dst.NumBytes() > 0 { - n, err := f.Preadv(ctx, dst, offset+total) - total += n - if err == io.EOF && total != 0 { - return total, io.ErrUnexpectedEOF - } else if err != nil { - return total, err - } - dst = dst.DropFirst64(n) - } - return total, nil -} - // openPath opens args.Filename and checks that it is valid for loading. // // openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not @@ -238,14 +221,14 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V // Load the executable itself. loaded, ac, file, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer file.DecRef() // Load the VDSO. vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. 
brk starts at the next page after the end of the diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index 165869028..05a294fe6 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -26,10 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -37,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) const vdsoPrelink = 0xffffffffff700000 @@ -55,52 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} { } } -// byteReader implements fs.FileOperations for reading from a []byte source. -type byteReader struct { - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` - +type byteFullReader struct { data []byte } -var _ fs.FileOperations = (*byteReader)(nil) - -// newByteReaderFile creates a fake file to read data from. -// -// TODO(gvisor.dev/issue/2921): Convert to VFS2. -func newByteReaderFile(ctx context.Context, data []byte) *fs.File { - // Create a fake inode. - inode := fs.NewInode( - ctx, - &fsutil.SimpleFileInode{}, - fs.NewPseudoMountSource(ctx), - fs.StableAttr{ - Type: fs.Anonymous, - DeviceID: anon.PseudoDevice.DeviceID(), - InodeID: anon.PseudoDevice.NextIno(), - BlockSize: usermem.PageSize, - }) - - // Use the fake inode to create a fake dirent. - dirent := fs.NewTransientDirent(inode) - defer dirent.DecRef() - - // Use the fake dirent to make a fake file. - flags := fs.FileFlags{Read: true, Pread: true} - return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{ - data: data, - }) -} - -func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { +func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } @@ -111,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ return int64(n), err } -func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { - panic("Write not supported") -} - // validateVDSO checks that the VDSO can be loaded by loadVDSO. // // VDSOs are special (see below). Since we are going to map the VDSO directly @@ -130,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq // * PT_LOAD segments don't extend beyond the end of the file. // // ctx may be nil if f does not need it. 
-func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) { +func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) { info, err := parseHeader(ctx, f) if err != nil { log.Infof("Unable to parse VDSO header: %v", err) @@ -248,13 +198,12 @@ func getSymbolValueFromVDSO(symbol string) (uint64, error) { // PrepareVDSO validates the system VDSO and returns a VDSO, containing the // param page for updating by the kernel. -func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) { - vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin)) +func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { + vdsoFile := &byteFullReader{data: vdsoBin} // First make sure the VDSO is valid. vdsoFile does not use ctx, so a // nil context can be passed. info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin))) - vdsoFile.DecRef() if err != nil { return nil, err } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index a98b66de1..2c95669cd 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -28,9 +28,21 @@ go_template_instance( }, ) +go_template_instance( + name = "file_range", + out = "file_range.go", + package = "memmap", + prefix = "File", + template = "//pkg/segment:generic_range", + types = { + "T": "uint64", + }, +) + go_library( name = "memmap", srcs = [ + "file_range.go", "mappable_range.go", "mapping_set.go", "mapping_set_impl.go", @@ -40,7 +52,7 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/sentry/platform", + "//pkg/safemem", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index c6db9fc8f..c188f6c29 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -19,12 +19,12 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 -// offsets to (platform.File, uint64 File offset) pairs. +// offsets to (File, uint64 File offset) pairs. // // See mm/mm.go for Mappable's place in the lock order. // @@ -74,7 +74,7 @@ type Mappable interface { // Translations are valid until invalidated by a callback to // MappingSpace.Invalidate or until the caller removes its mapping of the // translated range. Mappable implementations must ensure that at least one - // reference is held on all pages in a platform.File that may be the result + // reference is held on all pages in a File that may be the result // of a valid Translation. // // Preconditions: required.Length() > 0. optional.IsSupersetOf(required). @@ -100,7 +100,7 @@ type Translation struct { Source MappableRange // File is the mapped file. - File platform.File + File File // Offset is the offset into File at which this Translation begins. Offset uint64 @@ -110,9 +110,9 @@ type Translation struct { Perms usermem.AccessType } -// FileRange returns the platform.FileRange represented by t. -func (t Translation) FileRange() platform.FileRange { - return platform.FileRange{t.Offset, t.Offset + t.Source.Length()} +// FileRange returns the FileRange represented by t. +func (t Translation) FileRange() FileRange { + return FileRange{t.Offset, t.Offset + t.Source.Length()} } // CheckTranslateResult returns an error if (ts, terr) does not satisfy all @@ -361,3 +361,49 @@ type MMapOpts struct { // TODO(jamieliu): Replace entirely with MappingIdentity? 
Hint string
 }
+
+// File represents a host file that may be mapped into a platform.AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+}
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index a036ce53c..f9d0837a1 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -7,14 +7,14 @@ go_template_instance(
 name = "file_refcount_set",
 out = "file_refcount_set.go",
 imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
 },
 package = "mm",
 prefix = "fileRefcount",
 template = "//pkg/segment:generic_set",
 types = {
 "Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
 "Value": "int32",
 "Functions": "fileRefcountSetFunctions",
 },
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 379148903..1999ec706 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -20,7 +20,6 @@ import (
 "gvisor.dev/gvisor/pkg/refs"
 "gvisor.dev/gvisor/pkg/sentry/memmap"
 "gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
 "gvisor.dev/gvisor/pkg/sentry/usage"
 "gvisor.dev/gvisor/pkg/sync"
 "gvisor.dev/gvisor/pkg/syserror"
@@ -243,7 +242,7 @@ type aioMappable struct {
 refs.AtomicRefCount
 
 mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
 }
 
 var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 6db7c3d40..3e85964e4 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -25,7 +25,7 @@
 // Locks taken by memmap.Mappable.Translate
 // mm.privateRefs.mu
 // platform.AddressSpace locks
-// platform.File locks
+// memmap.File locks
 // mm.aioManager.mu
 // mm.AIOContext.mu
 //
@@ -396,7 +396,7 @@ type pma struct {
 // file is the file mapped by this pma. Only pmas for which file ==
 // MemoryManager.mfp.MemoryFile() may be saved.
pmas hold a reference to // the corresponding file range while they exist. - file platform.File `state:"nosave"` + file memmap.File `state:"nosave"` // off is the offset into file at which this pma begins. // @@ -436,7 +436,7 @@ type pma struct { private bool // If internalMappings is not empty, it is the cached return value of - // file.MapInternal for the platform.FileRange mapped by this pma. + // file.MapInternal for the memmap.FileRange mapped by this pma. internalMappings safemem.BlockSeq `state:"nosave"` } @@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 { func (fileRefcountSetFunctions) ClearValue(_ *int32) { } -func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) { +func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) { return rc1, rc1 == rc2 } -func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) { +func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) { return rc, rc } diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 62e4c20af..930ec895f 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat } } -// Pin returns the platform.File ranges currently mapped by addresses in ar in +// Pin returns the memmap.File ranges currently mapped by addresses in ar in // mm, acquiring a reference on the returned ranges which the caller must // release by calling Unpin. If not all addresses are mapped, Pin returns a // non-nil error. Note that Pin may return both a non-empty slice of @@ -674,15 +673,15 @@ type PinnedRange struct { Source usermem.AddrRange // File is the mapped file. - File platform.File + File memmap.File // Offset is the offset into File at which this PinnedRange begins. Offset uint64 } -// FileRange returns the platform.File offsets mapped by pr. -func (pr PinnedRange) FileRange() platform.FileRange { - return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} +// FileRange returns the memmap.File offsets mapped by pr. +func (pr PinnedRange) FileRange() memmap.FileRange { + return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} } // Unpin releases the reference held by prs. @@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf } // incPrivateRef acquires a reference on private pages in fr. -func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { +func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) { mm.privateRefs.mu.Lock() defer mm.privateRefs.mu.Unlock() refSet := &mm.privateRefs.refs @@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { } // decPrivateRef releases a reference on private pages in fr. 
-func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) { - var freed []platform.FileRange +func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) { + var freed []memmap.FileRange mm.privateRefs.mu.Lock() refSet := &mm.privateRefs.refs @@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa // Discard internal mappings instead of trying to merge them, since merging // them requires an allocation and getting them again from the - // platform.File might not. + // memmap.File might not. pma1.internalMappings = safemem.BlockSeq{} return pma1, true } @@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error { return nil } -func (pseg pmaIterator) fileRange() platform.FileRange { +func (pseg pmaIterator) fileRange() memmap.FileRange { return pseg.fileRangeOf(pseg.Range()) } // Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0. -func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { +func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { panic("terminal pma iterator") @@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { pma := pseg.ValuePtr() pstart := pseg.Start() - return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} + return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 9ad52082d..0e142fb11 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -35,7 +34,7 @@ type SpecialMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange name string } @@ -44,7 +43,7 @@ type SpecialMappable struct { // SpecialMappable will use the given name in /proc/[pid]/maps. // // Preconditions: fr.Length() != 0. -func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable { +func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} m.EnableLeakCheck("mm.SpecialMappable") return &m @@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider { // FileRange returns the offsets into MemoryFileProvider().MemoryFile() that // store the SpecialMappable's contents. 
-func (m *SpecialMappable) FileRange() platform.FileRange { +func (m *SpecialMappable) FileRange() memmap.FileRange { return m.fr } diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index a9836ba71..7a3311a70 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -36,14 +36,14 @@ go_template_instance( "trackGaps": "1", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "usage", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "usageInfo", "Functions": "usageSetFunctions", }, @@ -56,14 +56,14 @@ go_template_instance( "minDegree": "10", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "reclaim", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "reclaimSetValue", "Functions": "reclaimSetFunctions", }, @@ -89,9 +89,10 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/hostmm", - "//pkg/sentry/platform", + "//pkg/sentry/memmap", "//pkg/sentry/usage", "//pkg/state", + "//pkg/state/wire", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 46f19d218..3243d7214 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -33,14 +33,14 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) -// MemoryFile is a platform.File whose pages may be allocated to arbitrary +// MemoryFile is a memmap.File whose pages may be allocated to arbitrary // users. type MemoryFile struct { // opts holds options passed to NewMemoryFile. opts is immutable. @@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() { // to Allocate. // // Preconditions: length must be page-aligned and non-zero. -func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) { +func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) { if length == 0 || length%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid allocation length: %#x", length)) } @@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Find a range in the underlying file. fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) if !ok { - return platform.FileRange{}, syserror.ENOMEM + return memmap.FileRange{}, syserror.ENOMEM } // Expand the file if needed. @@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Round the new file size up to be chunk-aligned. 
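// (chunkMask is chunkSize-1, so adding chunkMask and then clearing the
// low bits rounds int64(fr.End) up to the next chunkSize multiple.)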
newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask if err := f.file.Truncate(newFileSize); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } f.fileSize = newFileSize f.mappingsMu.Lock() @@ -416,7 +416,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi bs[i] = 0 } }); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } } if !f.usage.Add(fr, usageInfo{ @@ -439,56 +439,64 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // space for mappings to be allocated downwards. // // Precondition: alignment must be a power of 2. -func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) { alignmentMask := alignment - 1 - for gap := usage.UpperBoundGap(uint64(fileSize)); gap.Ok(); gap = gap.PrevLargeEnoughGap(length) { - // Start searching only at end of file. + + // Search for space in existing gaps, starting at the current end of the + // file and working backward. + lastGap := usage.LastGap() + gap := lastGap + for { end := gap.End() if end > uint64(fileSize) { end = uint64(fileSize) } - // Start at the top and align downwards. - start := end - length - if start > end { - break // Underflow. + // Try to allocate from the end of this gap, with the start of the + // allocated range aligned down to alignment. + unalignedStart := end - length + if unalignedStart > end { + // Negative overflow: this and all preceding gaps are too small to + // accommodate length. + break } - start &^= alignmentMask - - // Is the gap still sufficient? - if start < gap.Start() { - continue + if start := unalignedStart &^ alignmentMask; start >= gap.Start() { + return memmap.FileRange{start, start + length}, true } - // Allocate in the given gap. - return platform.FileRange{start, start + length}, true + gap = gap.PrevLargeEnoughGap(length) + if !gap.Ok() { + break + } } // Check that it's possible to fit this allocation at the end of a file of any size. - min := usage.LastGap().Start() + min := lastGap.Start() min = (min + alignmentMask) &^ alignmentMask if min+length < min { - // Overflow. - return platform.FileRange{}, false + // Overflow: allocation would exceed the range of uint64. + return memmap.FileRange{}, false } // Determine the minimum file size required to fit this allocation at its end. for { - if fileSize >= 2*fileSize { - // Is this because it's initially empty? - if fileSize == 0 { - fileSize += chunkSize - } else { - // fileSize overflow. - return platform.FileRange{}, false + newFileSize := 2 * fileSize + if newFileSize <= fileSize { + if fileSize != 0 { + // Overflow: allocation would exceed the range of int64. + return memmap.FileRange{}, false } - } else { - // Double the current fileSize. - fileSize *= 2 + newFileSize = chunkSize + } + fileSize = newFileSize + + unalignedStart := uint64(fileSize) - length + if unalignedStart > uint64(fileSize) { + // Negative overflow: fileSize is still inadequate. 
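+ // (For example, growing from one page to fit a four-page allocation
+ // requires two doublings; the "doubles file size more than once" test
+ // case added below exercises this path.)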
+ continue } - start := (uint64(fileSize) - length) &^ alignmentMask - if start >= min { - return platform.FileRange{start, start + length}, true + if start := unalignedStart &^ alignmentMask; start >= min { + return memmap.FileRange{start, start + length}, true } } } @@ -500,22 +508,22 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 // by r.ReadToBlocks(), it returns that error. // // Preconditions: length > 0. length must be page-aligned. -func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) { +func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) { fr, err := f.Allocate(length, kind) if err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } dsts, err := f.MapInternal(fr, usermem.Write) if err != nil { f.DecRef(fr) - return platform.FileRange{}, err + return memmap.FileRange{}, err } n, err := safemem.ReadFullToBlocks(r, dsts) un := uint64(usermem.Addr(n).RoundDown()) if un < length { // Free unused memory and update fr to contain only the memory that is // still allocated. - f.DecRef(platform.FileRange{fr.Start + un, fr.End}) + f.DecRef(memmap.FileRange{fr.Start + un, fr.End}) fr.End = fr.Start + un } return fr, err @@ -532,7 +540,7 @@ const ( // will read zeroes. // // Preconditions: fr.Length() > 0. -func (f *MemoryFile) Decommit(fr platform.FileRange) error { +func (f *MemoryFile) Decommit(fr memmap.FileRange) error { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -552,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error { return nil } -func (f *MemoryFile) markDecommitted(fr platform.FileRange) { +func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() // Since we're changing the knownCommitted attribute, we need to merge @@ -573,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) { f.usage.MergeRange(fr) } -// IncRef implements platform.File.IncRef. -func (f *MemoryFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (f *MemoryFile) IncRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -592,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) } -// DecRef implements platform.File.DecRef. -func (f *MemoryFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (f *MemoryFile) DecRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -629,8 +637,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } } -// MapInternal implements platform.File.MapInternal. -func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. 
+func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { if !fr.WellFormed() || fr.Length() == 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -656,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) ( // forEachMappingSlice invokes fn on a sequence of byte slices that // collectively map all bytes in fr. -func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error { +func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error { mappings := f.mappings.Load().([]uintptr) for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { chunk := int(chunkStart >> chunkShift) @@ -936,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( continue case !populated && populatedRun: // Finish the run by changing this segment. - runRange := platform.FileRange{ + runRange := memmap.FileRange{ Start: r.Start + uint64(populatedRunStart*usermem.PageSize), End: r.Start + uint64(i*usermem.PageSize), } @@ -1001,7 +1009,7 @@ func (f *MemoryFile) File() *os.File { return f.file } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (f *MemoryFile) FD() int { return int(f.file.Fd()) } @@ -1082,13 +1090,13 @@ func (f *MemoryFile) runReclaim() { // // Note that there returned range will be removed from tracking. It // must be reclaimed (removed from f.usage) at this point. -func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { +func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() for { for { if f.destroyed { - return platform.FileRange{}, false + return memmap.FileRange{}, false } if f.reclaimable { break @@ -1112,7 +1120,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } } -func (f *MemoryFile) markReclaimed(fr platform.FileRange) { +func (f *MemoryFile) markReclaimed(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) @@ -1214,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 { func (usageSetFunctions) ClearValue(val *usageInfo) { } -func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) { +func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) { return val1, val1 == val2 } -func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { +func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { return val, val } @@ -1262,10 +1270,10 @@ func (reclaimSetFunctions) MaxKey() uint64 { func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { } -func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { +func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { return reclaimSetValue{}, true } -func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { +func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { return reclaimSetValue{}, reclaimSetValue{} } diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go index b5b68eb52..405db141f 
100644 --- a/pkg/sentry/pgalloc/pgalloc_test.go +++ b/pkg/sentry/pgalloc/pgalloc_test.go @@ -143,6 +143,14 @@ func TestFindUnallocatedRange(t *testing.T) { start: hugepage, }, { + desc: "Allocation doubles file size more than once if necessary", + usage: &usageSegmentDataSlices{}, + fileSize: page, + length: 4 * page, + alignment: page, + start: 0, + }, + { desc: "Allocations are compact if possible", usage: &usageSegmentDataSlices{ Start: []uint64{page, 3 * page}, diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index f8385c146..78317fa35 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -26,11 +26,12 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" "gvisor.dev/gvisor/pkg/usermem" ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -79,10 +80,10 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { } // Save metadata. - if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { + if _, err := state.Save(ctx, w, &f.fileSize); err != nil { return err } - if err := state.Save(ctx, w, &f.usage, nil); err != nil { + if _, err := state.Save(ctx, w, &f.usage); err != nil { return err } @@ -115,9 +116,9 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error { // Load metadata. 
- if err := state.Load(ctx, r, &f.fileSize, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.fileSize); err != nil {
 return err
 }
 if err := f.file.Truncate(f.fileSize); err != nil {
@@ -125,7 +126,7 @@
 }
 newMappings := make([]uintptr, f.fileSize>>chunkShift)
 f.mappings.Store(newMappings)
- if err := state.Load(ctx, r, &f.usage, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.usage); err != nil {
 return err
 }
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 453241eca..209b28053 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,39 +1,21 @@
 load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
 
 package(licenses = ["notice"])
 
-go_template_instance(
- name = "file_range",
- out = "file_range.go",
- package = "platform",
- prefix = "File",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "uint64",
- },
-)
-
 go_library(
 name = "platform",
 srcs = [
 "context.go",
- "file_range.go",
 "mmap_min_addr.go",
 "platform.go",
 ],
 visibility = ["//pkg/sentry:internal"],
 deps = [
 "//pkg/abi/linux",
- "//pkg/atomicbitops",
 "//pkg/context",
- "//pkg/log",
- "//pkg/safecopy",
- "//pkg/safemem",
 "//pkg/seccomp",
 "//pkg/sentry/arch",
- "//pkg/sentry/usage",
- "//pkg/syserror",
+ "//pkg/sentry/memmap",
 "//pkg/usermem",
 ],
 )
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 4792454c4..b5d27a72a 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -47,6 +47,7 @@ go_library(
 "//pkg/safecopy",
 "//pkg/seccomp",
 "//pkg/sentry/arch",
+ "//pkg/sentry/memmap",
 "//pkg/sentry/platform",
 "//pkg/sentry/platform/interrupt",
 "//pkg/sentry/platform/ring0",
@@ -60,6 +61,7 @@ go_library(
 go_test(
 name = "kvm_test",
 srcs = [
+ "kvm_amd64_test.go",
 "kvm_test.go",
 "virtual_map_test.go",
 ],
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index faf1d5e1c..98a3e539d 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -18,6 +18,7 @@ import (
 "sync/atomic"
 
 "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
 "gvisor.dev/gvisor/pkg/sentry/platform"
 "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
 "gvisor.dev/gvisor/pkg/sync"
@@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.
 }
 
 // MapFile implements platform.AddressSpace.MapFile.
-func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
 as.mu.Lock()
 defer as.mu.Unlock()
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 0d1e83e6c..03a98512e 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -17,6 +17,7 @@
 package kvm
 
 import (
+ "syscall"
 "unsafe"
 
 "gvisor.dev/gvisor/pkg/sentry/arch"
@@ -60,3 +61,27 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
 func getHypercallID(addr uintptr) int {
 return _KVM_HYPERCALL_MAX
 }
+
+// bluepillStopGuest is responsible for injecting an interrupt.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ // Interrupt: we must have requested an interrupt
+ // window; set the interrupt line.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_INTERRUPT,
+ uintptr(unsafe.Pointer(&bounce))); errno != 0 {
+ throw("interrupt injection failed")
+ }
+ // Clear previous injection request.
+ c.runData.requestInterruptWindow = 0
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return c.runData.readyForInterruptInjection != 0
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 83643c602..dba563160 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -26,6 +26,17 @@ import (
 var (
 // The action for bluepillSignal is changed by sigaction().
 bluepillSignal = syscall.SIGILL
+
+ // vcpuSErr is the event used to inject a system error (sError).
+ vcpuSErr = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 0,
+ pad: [6]uint8{0, 0, 0, 0, 0, 0},
+ sErrEsr: 1,
+ },
+ rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
 )
 
 // bluepillArchEnter is called during bluepillEnter.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index abd36f973..8b64f3a1e 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -17,6 +17,7 @@
 package kvm
 
 import (
+ "syscall"
 "unsafe"
 
 "gvisor.dev/gvisor/pkg/sentry/arch"
@@ -74,3 +75,23 @@ func getHypercallID(addr uintptr) int {
 return int(((addr) - arm64HypercallMMIOBase) >> 3)
 }
 }
+
+// bluepillStopGuest is responsible for injecting an sError.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return true
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index a5b9be36d..bf357de1a 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -133,12 +133,12 @@ func bluepillHandler(context unsafe.Pointer) {
 // PIC, we can't inject an interrupt while they are
 // masked. We need to request a window if it's not
 // ready.
- if c.runData.readyForInterruptInjection == 0 {
- c.runData.requestInterruptWindow = 1
- continue // Rerun vCPU.
- } else {
+ if bluepillReadyStopGuest(c) {
 // Force injection below; the vCPU is ready.
 c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN
+ } else {
+ c.runData.requestInterruptWindow = 1
+ continue // Rerun vCPU.
 }
 case syscall.EFAULT:
 // If a fault is not serviceable due to the host
@@ -217,17 +217,7 @@ func bluepillHandler(context unsafe.Pointer) {
 }
 }
 case _KVM_EXIT_IRQ_WINDOW_OPEN:
- // Interrupt: we must have requested an interrupt
- // window; set the interrupt line.
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_INTERRUPT,
- uintptr(unsafe.Pointer(&bounce))); errno != 0 {
- throw("interrupt injection failed")
- }
- // Clear previous injection request.
- c.runData.requestInterruptWindow = 0 + bluepillStopGuest(c) case _KVM_EXIT_SHUTDOWN: c.die(bluepillArchContext(context), "shutdown") return diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go new file mode 100644 index 000000000..c0b4fd374 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +func TestSegments(t *testing.T) { + applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { + testutil.SetTestSegments(regs) + for { + var si arch.SignalInfo + if _, err := c.SwitchToUser(ring0.SwitchOpts{ + Registers: regs, + FloatingPointState: dummyFPState, + PageTables: pt, + FullRestore: true, + }, &si); err == platform.ErrContextInterrupt { + continue // Retry. + } else if err != nil { + t.Errorf("application segment check with full restore got unexpected error: %v", err) + } + if err := testutil.CheckTestSegments(regs); err != nil { + t.Errorf("application segment check with full restore failed: %v", err) + } + break // Done. + } + return false + }) +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 3134a076b..0b06a923a 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -46,6 +46,18 @@ type userRegs struct { fpRegs userFpsimdState } +type exception struct { + sErrPending uint8 + sErrHasEsr uint8 + pad [6]uint8 + sErrEsr uint64 +} + +type kvmVcpuEvents struct { + exception + rsvd [12]uint32 +} + // updateGlobalOnce does global initialization. It has to be called only once. func updateGlobalOnce(fd int) error { physicalInit() diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 6a7676468..3bf918446 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -35,6 +35,8 @@ const ( _KVM_GET_SUPPORTED_CPUID = 0xc008ae05 _KVM_SET_CPUID2 = 0x4008ae90 _KVM_SET_SIGNAL_MASK = 0x4004ae8b + _KVM_GET_VCPU_EVENTS = 0x8040ae9f + _KVM_SET_VCPU_EVENTS = 0x4040aea0 ) // KVM exit reasons. @@ -54,8 +56,10 @@ const ( // KVM capability options. const ( - _KVM_CAP_MAX_VCPUS = 0x42 - _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 + _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 + _KVM_CAP_VCPU_EVENTS = 0x29 + _KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e ) // KVM limits. 
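Both architectures now reach guest-stop through the same pair of hooks shown above. A simplified, runnable sketch of the dispatch; the types here are hypothetical stand-ins, since the real handler in bluepill_unsafe.go operates on *vCPU and raw KVM exit reasons:

    package main

    import "fmt"

    // vcpu is a hypothetical stand-in for kvm's internal vCPU state.
    type vcpu struct {
    	requestInterruptWindow uint32
    	ready                  bool
    }

    // On amd64 readiness tracks runData.readyForInterruptInjection; on
    // arm64 an sError can always be injected, so the hook returns true.
    func bluepillReadyStopGuest(c *vcpu) bool { return c.ready }

    // bluepillStopGuest performs the arch-specific injection
    // (KVM_INTERRUPT on amd64, KVM_SET_VCPU_EVENTS with an sError on
    // arm64).
    func bluepillStopGuest(c *vcpu) { fmt.Println("inject") }

    // handleIntr mirrors the EINTR path of bluepillHandler: stop the
    // guest if it is ready, otherwise request an interrupt window and
    // rerun the vCPU.
    func handleIntr(c *vcpu) {
    	if bluepillReadyStopGuest(c) {
    		bluepillStopGuest(c) // reached via _KVM_EXIT_IRQ_WINDOW_OPEN
    	} else {
    		c.requestInterruptWindow = 1
    	}
    }

    func main() {
    	handleIntr(&vcpu{ready: true})
    }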
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 6c8f4fa28..45b3180f1 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -262,30 +262,6 @@ func TestRegistersFault(t *testing.T) { }) } -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - func TestBounce(t *testing.T) { applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { go func() { diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 48c834499..3de309c1a 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -78,19 +78,6 @@ func (c *vCPU) initArchState() error { return err } - // sctlr_el1 - regGet.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.getOneRegister(®Get); err != nil { - return err - } - - dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I) - data = dataGet - reg.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.setOneRegister(®); err != nil { - return err - } - // tcr_el1 data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS reg.id = _KVM_ARM64_REGS_TCR_EL1 @@ -273,8 +260,8 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.PageFault: return c.fault(int32(syscall.SIGSEGV), info) - case 0xaa: - return usermem.NoAccess, nil + case ring0.Vector(bounce): // ring0.VirtualizationException + return usermem.NoAccess, platform.ErrContextInterrupt default: return usermem.NoAccess, platform.ErrContextSignal } diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 0bebee852..07658144e 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -104,3 +104,9 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 TWIDDLE_REGS() SVC RET // never reached + +TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 + TWIDDLE_REGS() + // Branch to Register branches unconditionally to an address in <Rn>. + JMP (R4) // <=> br x4, must fault + RET // never reached diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 171513f3f..4b13eec30 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -22,9 +22,9 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -207,7 +207,7 @@ type AddressSpace interface { // Preconditions: addr and fr must be page-aligned. fr.Length() > 0. // at.Any() == true. At least one reference must be held on all pages in // fr, and must continue to be held as long as pages are mapped. 
- MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error + MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error // Unmap unmaps the given range. // @@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string { return fmt.Sprintf("segmentation fault at %#x", f.Addr) } -// File represents a host file that may be mapped into an AddressSpace. -type File interface { - // All pages in a File are reference-counted. - - // IncRef increments the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. (The File - // interface does not provide a way to acquire an initial reference; - // implementors may define mechanisms for doing so.) - IncRef(fr FileRange) - - // DecRef decrements the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. - DecRef(fr FileRange) - - // MapInternal returns a mapping of the given file offsets in the invoking - // process' address space for reading and writing. - // - // Note that fr.Start and fr.End need not be page-aligned. - // - // Preconditions: fr.Length() > 0. At least one reference must be held on - // all pages in fr. - // - // Postconditions: The returned mapping is valid as long as at least one - // reference is held on the mapped pages. - MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) - - // FD returns the file descriptor represented by the File. - // - // The only permitted operation on the returned file descriptor is to map - // pages from it consistent with the requirements of AddressSpace.MapFile. - FD() int -} - -// FileRange represents a range of uint64 offsets into a File. -// -// type FileRange <generated using go_generics> - -// String implements fmt.Stringer.String. -func (fr FileRange) String() string { - return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) -} - // Requirements is used to specify platform specific requirements. type Requirements struct { // RequiresCurrentPIDNS indicates that the sandbox has to be started in the diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 30402c2df..29fd23cc3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/hostcpu", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sync", diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 2389423b0..c990f3454 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp } // MapFile implements platform.AddressSpace.MapFile. 
-func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { var flags int if precommit { flags |= syscall.MAP_POPULATE diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2bc5f3ecd..6ed73699b 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -40,6 +40,14 @@ #define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) +// sctlr_el1: system control register el1. +#define SCTLR_M 1 << 0 +#define SCTLR_C 1 << 2 +#define SCTLR_I 1 << 12 +#define SCTLR_UCT 1 << 15 + +#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT) + // Saves a register set. // // This is a macro because it may need to executed in contents where a stack is @@ -496,6 +504,11 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 IRQ_DISABLE + + // Init. + MOVD $SCTLR_EL1_DEFAULT, R1 + MSR R1, SCTLR_EL1 + MOVD R8, RSV_REG ORR $0xffff000000000000, RSV_REG, RSV_REG WORD $0xd518d092 //MSR R18, TPIDR_EL1 diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index ccacaea6b..fca3a5478 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -58,7 +58,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate &= ^uint64(UserFlagsClear) regs.Pstate |= UserFlagsSet + + SetTLS(regs.TPIDR_EL0) + kernelExitToEl0() + + regs.TPIDR_EL0 = GetTLS() + vector = c.vecCode // Perform the switch. diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 60c9896fc..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -26,6 +26,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/hostfd", "//pkg/sentry/inet", @@ -34,12 +35,13 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index c11e82c10..532a1ea5d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -319,12 +321,12 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. 
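The hostinet GetSockOpt change just below is one instance of a pattern repeated throughout this diff: returning marshal.Marshallable instead of interface{} so option values can be copied out to userspace without reflection. A cut-down sketch of the convention, assuming simplified stand-ins for the go_marshal interface and the primitive wrappers:

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // Marshallable is a cut-down stand-in for the go_marshal interface;
    // only the methods this sketch needs are included.
    type Marshallable interface {
        SizeBytes() int
        MarshalBytes(dst []byte)
    }

    // Int32 and ByteSlice imitate the tools/go_marshal/primitive wrappers
    // that let plain values travel through a Marshallable-typed return.
    type Int32 int32

    func (i *Int32) SizeBytes() int { return 4 }
    func (i *Int32) MarshalBytes(dst []byte) {
        binary.LittleEndian.PutUint32(dst, uint32(*i))
    }

    type ByteSlice []byte

    func (b *ByteSlice) SizeBytes() int          { return len(*b) }
    func (b *ByteSlice) MarshalBytes(dst []byte) { copy(dst, *b) }

    // getSockOpt sketches the new convention: every branch returns a
    // pointer to something Marshallable, so the syscall layer can copy
    // results out uniformly, without type switches on interface{}.
    func getSockOpt(raw []byte) (Marshallable, error) {
        opt := ByteSlice(raw)
        return &opt, nil
    }

    func main() {
        m, _ := getSockOpt([]byte{1, 2, 3})
        buf := make([]byte, m.SizeBytes())
        m.MarshalBytes(buf)
        fmt.Println(buf) // [1 2 3]

        v := Int32(7)
        out := make([]byte, v.SizeBytes())
        v.MarshalBytes(out)
        fmt.Println(out) // [7 0 0 0]
    }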
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } - // Whitelist options and constrain option length. + // Only allow known and safe options. optlen := getSockOptLen(t, level, name) switch level { case linux.SOL_IP: @@ -364,12 +366,13 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { - // Whitelist options and constrain option length. + // Only allow known and safe options. optlen := setSockOptLen(t, level, name) switch level { case linux.SOL_IP: @@ -415,7 +418,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Whitelist flags. + // Only allow known and safe flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the @@ -537,7 +540,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // SendMsg implements socket.Socket.SendMsg. func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { - // Whitelist flags. + // Only allow known and safe flags. 
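The SendMsg flag check that resumes just below relies on Go's bit-clear operator: clearing every supported flag must leave zero, otherwise some unsupported flag was set. A minimal sketch of the same idiom, using a reduced flag set for brevity:

    package main

    import (
        "fmt"
        "syscall"
    )

    // allowed plays the role of the mask below: the set of flags hostinet
    // is willing to pass through to the host.
    const allowed = syscall.MSG_DONTWAIT | syscall.MSG_EOR | syscall.MSG_NOSIGNAL

    // validate rejects any flags outside the allowed set. flags &^ allowed
    // clears the allowed bits; anything left over is unsupported.
    func validate(flags int) error {
        if flags&^allowed != 0 {
            return syscall.EINVAL
        }
        return nil
    }

    func main() {
        fmt.Println(validate(syscall.MSG_DONTWAIT))                   // <nil>
        fmt.Println(validate(syscall.MSG_DONTWAIT | syscall.MSG_OOB)) // invalid argument
    }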
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument } @@ -708,6 +711,6 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) - socket.RegisterProviderVFS2(family, &socketProviderVFS2{}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{family}) } } diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 027add1fd..8a1d52ebf 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -21,12 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -61,7 +61,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in fd: fd, }, } - s.LockFD.Init(&lock.FileLocks{}) + s.LockFD.Init(&vfs.FileLocks{}) if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -71,6 +71,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in DenyPWrite: true, UseDentryMetadata: true, }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) return nil, syserr.FromError(err) } return vfsfd, nil @@ -96,7 +97,12 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// PRead implements vfs.FileDescriptionImpl. +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + +// PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE } @@ -134,6 +140,16 @@ func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return int64(n), err } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + type socketProviderVFS2 struct { family int } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 66015e2bc..a9f0604ae 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -41,19 +41,6 @@ const errorTargetName = "ERROR" // change the destination port/destination IP for packets. 
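Note the one-character bug fixed in the init() loop above: the VFS2 provider was previously registered as &socketProviderVFS2{} for every family, leaving its family field zero for both AF_INET and AF_INET6. A toy reproduction of the bug class and the fix, with hypothetical names:

    package main

    import "fmt"

    const (
        afInet  = 2  // AF_INET
        afInet6 = 10 // AF_INET6
    )

    type provider struct{ family int }

    func main() {
        registry := map[int]*provider{}
        // The old loop registered a zero-valued provider for every family,
        // so lookups for AF_INET and AF_INET6 both saw family 0. The fix
        // threads the loop variable through, as modeled here.
        for _, family := range []int{afInet, afInet6} {
            registry[family] = &provider{family} // was &provider{}
        }
        for _, fam := range []int{afInet, afInet6} {
            fmt.Printf("family %d -> provider family %d\n", fam, registry[fam].family)
        }
    }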
const redirectTargetName = "REDIRECT" -// Metadata is used to verify that we are correctly serializing and -// deserializing iptables into structs consumable by the iptables tool. We save -// a metadata struct when the tables are written, and when they are read out we -// verify that certain fields are the same. -// -// metadata is used by this serialization/deserializing code, not netstack. -type metadata struct { - HookEntry [linux.NF_INET_NUMHOOKS]uint32 - Underflow [linux.NF_INET_NUMHOOKS]uint32 - NumEntries uint32 - Size uint32 -} - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. @@ -79,33 +66,17 @@ func nflog(format string, args ...interface{}) { func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo - if _, err := t.CopyIn(outPtr, &info); err != nil { + if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, info.Name) + _, info, err := convertNetstackToBinary(stack, info.Name) if err != nil { - nflog("%v", err) + nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } - // Get the hooks that apply to this table. - info.ValidHooks = table.ValidHooks() - - // Grab the metadata struct, which is used to store information (e.g. - // the number of entries) that applies to the user's encoding of - // iptables, but not netstack's. - metadata := table.Metadata().(metadata) - - // Set values from metadata. - info.HookEntry = metadata.HookEntry - info.Underflow = metadata.Underflow - info.NumEntries = metadata.NumEntries - info.Size = metadata.Size - nflog("returning info: %+v", info) - return info, nil } @@ -113,28 +84,18 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries - if _, err := t.CopyIn(outPtr, &userEntries); err != nil { + if _, err := userEntries.CopyIn(t, outPtr); err != nil { nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, userEntries.Name) - if err != nil { - nflog("%v", err) - return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument - } - // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table) + entries, _, err := convertNetstackToBinary(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if meta != table.Metadata().(metadata) { - panic(fmt.Sprintf("Table %q metadata changed between writing and reading. 
Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta)) - } if binary.Size(entries) > uintptr(outLen) { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -143,44 +104,26 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - table, ok := stk.IPTables().GetTable(tablename.String()) - if !ok { - return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) - } - return table, nil -} - -// FillIPTablesMetadata populates stack's IPTables with metadata. -func FillIPTablesMetadata(stk *stack.Stack) { - stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) - } - table.SetMetadata(metadata) - tables[name] = table - } - }) -} - // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { - // Return values. +func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + table, ok := stack.IPTables().GetTable(tablename.String()) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + var entries linux.KernelIPTGetEntries - var meta metadata + var info linux.IPTGetinfo + info.ValidHooks = table.ValidHooks() // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename) + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - copy(entries.Name[:], tablename) + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], tablename[:]) for ruleIdx, rule := range table.Rules { nflog("convert to binary: current offset: %d", entries.Size) @@ -189,20 +132,20 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI for hook, hookRuleIdx := range table.BuiltinChains { if hookRuleIdx == ruleIdx { nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - meta.HookEntry[hook] = entries.Size + info.HookEntry[hook] = entries.Size } } // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - meta.Underflow[underflow] = entries.Size + info.Underflow[underflow] = entries.Size } } // Each rule corresponds to an entry. 
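The HookEntry/Underflow bookkeeping above works by watching byte offsets during a single in-order serialization pass: when the current rule index matches a chain's entry point, the offset accumulated so far is recorded. A sketch with made-up rule sizes:

    package main

    import "fmt"

    func main() {
        // Serialized size of each rule, in order.
        ruleSizes := []uint32{112, 152, 112, 112}
        // Rule index at which each builtin chain begins.
        builtinChains := map[string]int{"INPUT": 0, "OUTPUT": 2}

        hookEntry := map[string]uint32{}
        var offset uint32
        for ruleIdx, size := range ruleSizes {
            for hook, entryIdx := range builtinChains {
                if entryIdx == ruleIdx {
                    hookEntry[hook] = offset // chain begins at this byte offset
                }
            }
            offset += size
        }
        fmt.Println(hookEntry) // map[INPUT:0 OUTPUT:264]
    }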
entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -210,20 +153,20 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP } if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -235,8 +178,8 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -245,18 +188,18 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ + info.NumEntries++ } - nflog("convert to binary: finished with an marshalled size of %d", meta.Size) - meta.Size = entries.Size - return entries, meta, nil + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + info.Size = entries.Size + return entries, info, nil } func marshalTarget(target stack.Target) []byte { @@ -399,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. 
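The offset arithmetic above is the crux of the ipt_entry encoding: TargetOffset must point past the fixed header plus all serialized matchers, while NextOffset must also cover the target. A sketch, assuming the conventional 112-byte struct ipt_entry header:

    package main

    import "fmt"

    const sizeOfIPTEntry = 112 // sizeof(struct ipt_entry), assumed here

    type entry struct {
        NextOffset   uint16
        TargetOffset uint16
        Elems        []byte
    }

    // appendMatcher grows both offsets: the target and the next entry both
    // sit after every matcher.
    func (e *entry) appendMatcher(serialized []byte) {
        e.Elems = append(e.Elems, serialized...)
        e.NextOffset += uint16(len(serialized))
        e.TargetOffset += uint16(len(serialized))
    }

    // appendTarget grows only NextOffset: the target itself starts at
    // TargetOffset, so only the following entry moves.
    func (e *entry) appendTarget(serialized []byte) {
        e.Elems = append(e.Elems, serialized...)
        e.NextOffset += uint16(len(serialized))
    }

    func main() {
        e := entry{NextOffset: sizeOfIPTEntry, TargetOffset: sizeOfIPTEntry}
        e.appendMatcher(make([]byte, 48))
        e.appendTarget(make([]byte, 40))
        fmt.Printf("TargetOffset=%d NextOffset=%d\n", e.TargetOffset, e.NextOffset) // 160, 200
    }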
var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -488,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<<hook) != 0 { hk := hookFromLinux(hook) + table.BuiltinChains[hk] = stack.HookUnset + table.Underflows[hk] = stack.HookUnset for offset, ruleIdx := range offsets { if offset == replace.HookEntry[hook] { table.BuiltinChains[hk] = ruleIdx @@ -513,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Add the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(stack.UserChainTarget) - if !ok { + if _, ok := rule.Target.(stack.UserChainTarget); !ok { continue } @@ -530,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { nflog("user chain's first node must have no matchers") return syserr.ErrInvalidArgument } - table.UserChains[target.Name] = ruleIdx + 1 } // Set each jump to point to the appropriate rule. Right now they hold byte @@ -556,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if ruleIdx == stack.HookUnset { + continue + } if !isUnconditionalAccept(table.Rules[ruleIdx]) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument @@ -569,15 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - table.SetMetadata(metadata{ - HookEntry: replace.HookEntry, - Underflow: replace.Underflow, - NumEntries: replace.NumEntries, - Size: replace.Size, - }) - stk.IPTables().ReplaceTable(replace.Name.String(), table) - - return nil + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 84abe8d29..b91ba3ab3 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -30,6 +30,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. 
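The stack.HookUnset handling above keeps an unwired chain (one whose index was reset and never filled in) from being mistaken for a chain starting at rule 0. A sketch of the resulting guard for unsupported chains, with -1 standing in for HookUnset:

    package main

    import "fmt"

    const hookUnset = -1

    // checkUnsupportedChains models the new check: chains the sentry does
    // not support must either be absent for this table (left at hookUnset)
    // or consist of a single unconditional ACCEPT.
    func checkUnsupportedChains(builtin map[string]int, unconditionalAccept func(int) bool) error {
        for _, hook := range []string{"FORWARD", "POSTROUTING"} {
            ruleIdx, ok := builtin[hook]
            if !ok || ruleIdx == hookUnset {
                continue // hook not wired up for this table
            }
            if !unconditionalAccept(ruleIdx) {
                return fmt.Errorf("hook %s is unsupported", hook)
            }
        }
        return nil
    }

    func main() {
        accepts := map[int]bool{3: true}
        chains := map[string]int{"INPUT": 0, "FORWARD": hookUnset, "POSTROUTING": 3}
        err := checkUnsupportedChains(chains, func(i int) bool { return accepts[i] })
        fmt.Println("err:", err) // err: <nil>
    }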
-func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 420e573c9..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -29,13 +30,14 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..98ca7add0 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. 
- return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index 8bfee5193..dbcd8b49a 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -18,12 +18,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -78,7 +78,7 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV sendBufferSize: defaultSendBufferSize, }, } - fd.LockFD.Init(&lock.FileLocks{}) + fd.LockFD.Init(&vfs.FileLocks{}) return fd, nil } @@ -140,3 +140,13 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. 
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 0f592ecc3..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", @@ -37,7 +38,6 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", @@ -51,6 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 738277391..44b3fff46 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -61,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -191,6 +194,8 @@ var Metrics = tcpip.Stats{ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), + ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -295,8 +300,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -445,8 +451,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. 
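The comment above about interface ordering is load-bearing for the type switch in the cases that follow: a packet endpoint satisfies both tcpip.PacketEndpoint and tcpip.Endpoint, and a Go type switch takes the first case that matches. A self-contained demonstration with toy interfaces:

    package main

    import "fmt"

    type Endpoint interface{ Read() string }

    type PacketEndpoint interface {
        Endpoint
        ReadPacket() string
    }

    type packetEP struct{}

    func (packetEP) Read() string       { return "Read" }
    func (packetEP) ReadPacket() string { return "ReadPacket" }

    // read shows why the case order matters: since packetEP satisfies both
    // interfaces, the generic Endpoint case must come last or ReadPacket
    // would never be chosen.
    func read(ep Endpoint) string {
        switch e := ep.(type) {
        case PacketEndpoint:
            return e.ReadPacket()
        case Endpoint:
            return e.Read()
        }
        return ""
    }

    func main() {
        fmt.Println(read(packetEP{})) // ReadPacket, not Read
    }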
+ case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -893,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -903,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -939,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -954,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -964,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -997,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1008,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. 
err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1018,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1033,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1049,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1065,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1076,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1087,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1095,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1110,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. 
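The SO_SNDBUF and SO_RCVBUF branches above saturate rather than truncate when the endpoint's buffer size does not fit in the int32 that getsockopt(2) returns. The pattern in isolation:

    package main

    import (
        "fmt"
        "math"
    )

    // clamp mirrors the handling above: endpoints report sizes as native
    // ints, but getsockopt(2) must return an int32, so oversized values
    // saturate at MaxInt32 instead of wrapping to a negative number.
    func clamp(size int64) int32 {
        if size > math.MaxInt32 {
            return math.MaxInt32
        }
        return int32(size)
    }

    func main() {
        fmt.Println(clamp(1 << 20)) // 1048576
        fmt.Println(clamp(1 << 40)) // 2147483647
    }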
return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1121,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1132,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1146,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1154,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1166,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil + + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1175,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. 
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1186,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1197,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1208,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1219,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1231,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1243,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1255,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1267,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1281,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. 
- ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1316,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1328,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1340,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1351,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1363,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1372,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1383,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1391,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. 
if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1416,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIPv6(t, name) @@ -1425,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1438,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1454,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1468,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1479,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. 
if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1504,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1515,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIP(t, name) @@ -1719,6 +1802,14 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + case linux.SO_NO_CHECK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { return syserr.ErrInvalidArgument @@ -1733,6 +1824,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -2092,13 +2188,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -2419,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. 
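toLinuxPacketType above is an exhaustive translation table with a panicking default, on the theory that netstack should never hand the sentry a packet type it cannot name. A reduced sketch covering a subset of the PACKET_* values (0, 3, and 4, per Linux's if_packet.h):

    package main

    import "fmt"

    type packetType int

    const (
        packetHost packetType = iota
        packetOtherHost
        packetOutgoing
    )

    const (
        linuxPacketHost      = 0 // PACKET_HOST
        linuxPacketOtherHost = 3 // PACKET_OTHERHOST
        linuxPacketOutgoing  = 4 // PACKET_OUTGOING
    )

    // toLinux translates internal packet types to the wire constants; the
    // panic documents that an unknown value is a programming error, not a
    // user-triggerable condition.
    func toLinux(pt packetType) uint8 {
        switch pt {
        case packetHost:
            return linuxPacketHost
        case packetOtherHost:
            return linuxPacketOtherHost
        case packetOutgoing:
            return linuxPacketOutgoing
        default:
            panic(fmt.Sprintf("unknown packet type: %d", pt))
        }
    }

    func main() {
        fmt.Println(toLinux(packetOutgoing)) // 4
    }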
@@ -2474,6 +2596,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2708,11 +2835,16 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2720,9 +2852,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2741,9 +2871,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2752,52 +2881,49 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. 
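The SIOCGIFCONF case that follows implements the usual two-mode convention: a null buffer pointer asks only for the size needed, while a non-null one is filled with whole ifreq structs up to Len bytes. A sketch of the size accounting, assuming the 40-byte SizeOfIFReq used with this ioctl:

    package main

    import "fmt"

    const sizeOfIFReq = 40 // sizeof(struct ifreq), assumed here

    // ifconf models the convention handled below: with no buffer the
    // caller learns how many bytes are needed; with one, only whole ifreq
    // structs are written and the used length is reported back.
    func ifconf(haveBuf bool, bufLen, numIfaces int32) (outLen int32) {
        needed := numIfaces * sizeOfIFReq
        if !haveBuf {
            return needed // size query
        }
        if bufLen > needed {
            bufLen = needed
        }
        return bufLen - bufLen%sizeOfIFReq // whole structs only
    }

    func main() {
        fmt.Println(ifconf(false, 0, 3))  // 120: caller should allocate this much
        fmt.Println(ifconf(true, 100, 3)) // 80: only two whole ifreqs fit
    }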
var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2810,9 +2936,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2826,9 +2951,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -2854,7 +2978,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2877,21 +3001,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2900,7 +3031,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2911,32 +3042,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. 
As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2953,6 +3084,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument @@ -2962,7 +3101,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -2999,9 +3138,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. 
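// As the SIOCGIFHWADDR hunk above notes, the result mimics a struct
// sockaddr: the first two bytes carry the ARPHRD_* device type
// (sa_family) and the remainder the L2 address (sa_data). A standalone
// sketch of that packing, assuming the usermem import from the diff
// (packHWAddr and its parameters are illustrative):
func packHWAddr(data []byte, devType uint16, hwAddr []byte) {
	usermem.ByteOrder.PutUint16(data[:2], devType) // sa_family: ARPHRD_* type.
	n := copy(data[2:], hwAddr)                    // sa_data: hardware address.
	for i := 2 + n; i < len(data); i++ {
		data[i] = 0 // Clear trailing padding, as the ioctl does.
	}
}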
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 1412a4810..a9025b0ec 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -19,18 +19,20 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -66,7 +68,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu protocol: protocol, }, } - s.LockFD.Init(&lock.FileLocks{}) + s.LockFD.Init(&vfs.FileLocks{}) vfsfd := &s.vfsfd if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, @@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. 
if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -318,3 +320,13 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 9b44c2b89..67737ae87 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,10 +15,11 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } @@ -314,7 +325,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { udp.PacketsSent.Value(), // OutDatagrams. udp.ReceiveBufferErrors.Value(), // RcvbufErrors. 0, // Udp/SndbufErrors. - 0, // Udp/InCsumErrors. + udp.ChecksumErrors.Value(), // Udp/InCsumErrors. 0, // Udp/IgnoredMulti. } default: @@ -366,11 +377,6 @@ func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillIPTablesMetadata populates stack's IPTables with metadata. -func (s *Stack) FillIPTablesMetadata() { - netfilter.FillIPTablesMetadata(s.Stack) -} - // Resume implements inet.Stack.Resume. func (s *Stack) Resume() { s.Stack.Resume() diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 6580bd6e9..d112757fb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. 
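// Returning marshal.Marshallable rather than interface{} lets the
// generic getsockopt path copy any option value to user memory through
// one generated method instead of reflection-based encoding. A sketch
// of the caller side, assuming the marshal import added above
// (copyOutOpt is an illustrative helper):
func copyOutOpt(t *kernel.Task, addr usermem.Addr, v marshal.Marshallable) error {
	if v == nil {
		return nil // Some options have nothing to copy out.
	}
	_, err := v.CopyOut(t, addr)
	return err
}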
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error @@ -407,7 +408,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { linux.SO_MARK, linux.SO_MAX_PACING_RATE, linux.SO_NOFCS, - linux.SO_NO_CHECK, linux.SO_OOBINLINE, linux.SO_PASSCRED, linux.SO_PASSSEC, diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 7d4cc80fe..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", @@ -29,11 +30,11 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", - "//pkg/sentry/vfs/lock", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 09c6d3b27..a1e49cc57 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -476,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask // State implements socket.Socket.State. func (e *connectionedEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + if e.Connected() { return linux.SS_CONNECTED } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..0482d33cf 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. 
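// The connectionedEndpoint.State change above takes the endpoint mutex
// so Connected() sees a stable snapshot instead of racing with a
// concurrent connect or close. The shape of the fix in isolation,
// assuming sync and the abi/linux package (connState is an
// illustrative stand-in, not the real endpoint type):
type connState struct {
	mu        sync.Mutex
	connected bool
}

func (s *connState) State() uint32 {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.connected {
		return linux.SS_CONNECTED
	}
	return linux.SS_UNCONNECTED
}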
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 8c32371a2..05c16fcfe 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -26,12 +27,12 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -53,7 +54,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) - fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &lock.FileLocks{}) + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) if err != nil { return nil, syserr.FromError(err) } @@ -62,7 +63,7 @@ func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) // NewFileDescription creates and returns a socket file description // corresponding to the given mount and dentry. -func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *lock.FileLocks) (*vfs.FileDescription, error) { +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) { // You can create AF_UNIX, SOCK_RAW sockets. They're the same as // SOCK_DGRAM and don't require CAP_NET_RAW. if stype == linux.SOCK_RAW { @@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } @@ -300,6 +301,16 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. 
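// Both unix socket paths above now share one locking pattern: the FD
// embeds vfs.LockFD, Init points it at a vfs.FileLocks shared by all
// descriptions of the file, and LockPOSIX/UnlockPOSIX simply delegate
// to it. Reduced sketch, assuming the vfs import (lockedFD is
// illustrative):
type lockedFD struct {
	vfs.FileDescriptionDefaultImpl
	vfs.LockFD
}

func newLockedFD() *lockedFD {
	fd := &lockedFD{}
	fd.LockFD.Init(&vfs.FileLocks{}) // Shared byte-range lock state.
	return fd
}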
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + // providerVFS2 is a unix domain socket provider for VFS2. type providerVFS2 struct{} diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..4a9b04fd0 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -99,5 +99,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ea4f9b1a7..80c65164a 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{ 270: syscalls.Supported("pselect", Pselect), 271: syscalls.Supported("ppoll", Ppoll), 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 273: syscalls.Supported("set_robust_list", SetRobustList), + 274: syscalls.Supported("get_robust_list", GetRobustList), 275: syscalls.Supported("splice", Splice), 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index d781d6a04..ba2557c52 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -15,8 +15,8 @@ package linux import ( - "encoding/binary" - + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -27,59 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// I/O commands. -const ( - _IOCB_CMD_PREAD = 0 - _IOCB_CMD_PWRITE = 1 - _IOCB_CMD_FSYNC = 2 - _IOCB_CMD_FDSYNC = 3 - _IOCB_CMD_NOOP = 6 - _IOCB_CMD_PREADV = 7 - _IOCB_CMD_PWRITEV = 8 -) - -// I/O flags. -const ( - _IOCB_FLAG_RESFD = 1 -) - -// ioCallback describes an I/O request. -// -// The priority field is currently ignored in the implementation below. Also -// note that the IOCB_FLAG_RESFD feature is not supported. -type ioCallback struct { - Data uint64 - Key uint32 - Reserved1 uint32 - - OpCode uint16 - ReqPrio int16 - FD int32 - - Buf uint64 - Bytes uint64 - Offset int64 - - Reserved2 uint64 - Flags uint32 - - // eventfd to signal if IOCB_FLAG_RESFD is set in flags. - ResFD int32 -} - -// ioEvent describes an I/O result. -// -// +stateify savable -type ioEvent struct { - Data uint64 - Obj uint64 - Result int64 - Result2 int64 -} - -// ioEventSize is the size of an ioEvent encoded. -var ioEventSize = binary.Size(ioEvent{}) - // IoSetup implements linux syscall io_setup(2). func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { nrEvents := args[0].Int() @@ -192,7 +139,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S } } - ev := v.(*ioEvent) + ev := v.(*linux.IOEvent) // Copy out the result. 
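// IoGetevents above now copies out the exported linux.IOEvent and
// advances the destination by linux.IOEventSize, the encoded size of
// one event. The per-event step in isolation (writeEvent is an
// illustrative helper):
func writeEvent(t *kernel.Task, addr usermem.Addr, ev *linux.IOEvent) (usermem.Addr, error) {
	if _, err := t.CopyOut(addr, ev); err != nil {
		return addr, err
	}
	return addr + usermem.Addr(linux.IOEventSize), nil
}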
if _, err := t.CopyOut(eventsAddr, ev); err != nil { @@ -204,7 +151,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S } // Keep rolling. - eventsAddr += usermem.Addr(ioEventSize) + eventsAddr += usermem.Addr(linux.IOEventSize) } // Everything finished. @@ -231,7 +178,7 @@ func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadl } // memoryFor returns appropriate memory for the given callback. -func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) { +func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) { bytes := int(cb.Bytes) if bytes < 0 { // Linux also requires that this field fit in ssize_t. @@ -242,17 +189,17 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) { // we have no guarantee that t's AddressSpace will be active during the // I/O. switch cb.OpCode { - case _IOCB_CMD_PREAD, _IOCB_CMD_PWRITE: + case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE: return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ AddressSpaceActive: false, }) - case _IOCB_CMD_PREADV, _IOCB_CMD_PWRITEV: + case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV: return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ AddressSpaceActive: false, }) - case _IOCB_CMD_FSYNC, _IOCB_CMD_FDSYNC, _IOCB_CMD_NOOP: + case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP: return usermem.IOSequence{}, nil default: @@ -261,54 +208,62 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) { } } -func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) { - if ctx.Dead() { - ctx.CancelPendingRequest() - return - } - ev := &ioEvent{ - Data: cb.Data, - Obj: uint64(cbAddr), - } +// IoCancel implements linux syscall io_cancel(2). +// +// It is not presently supported (ENOSYS indicates no support on this +// architecture). +func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + return 0, nil, syserror.ENOSYS +} - // Construct a context.Context that will not be interrupted if t is - // interrupted. - c := t.AsyncContext() +// LINT.IfChange - var err error - switch cb.OpCode { - case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV: - ev.Result, err = file.Preadv(c, ioseq, cb.Offset) - case _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV: - ev.Result, err = file.Pwritev(c, ioseq, cb.Offset) - case _IOCB_CMD_FSYNC: - err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncAll) - case _IOCB_CMD_FDSYNC: - err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncData) - } +func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback { + return func(ctx context.Context) { + if actx.Dead() { + actx.CancelPendingRequest() + return + } + ev := &linux.IOEvent{ + Data: cb.Data, + Obj: uint64(cbAddr), + } - // Update the result. 
- if err != nil { - err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file) - ev.Result = -int64(kernel.ExtractErrno(err, 0)) - } + var err error + switch cb.OpCode { + case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV: + ev.Result, err = file.Preadv(ctx, ioseq, cb.Offset) + case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV: + ev.Result, err = file.Pwritev(ctx, ioseq, cb.Offset) + case linux.IOCB_CMD_FSYNC: + err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll) + case linux.IOCB_CMD_FDSYNC: + err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncData) + } + + // Update the result. + if err != nil { + err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file) + ev.Result = -int64(kernel.ExtractErrno(err, 0)) + } - file.DecRef() + file.DecRef() - // Queue the result for delivery. - ctx.FinishRequest(ev) + // Queue the result for delivery. + actx.FinishRequest(ev) - // Notify the event file if one was specified. This needs to happen - // *after* queueing the result to avoid racing with the thread we may - // wake up. - if eventFile != nil { - eventFile.FileOperations.(*eventfd.EventOperations).Signal(1) - eventFile.DecRef() + // Notify the event file if one was specified. This needs to happen + // *after* queueing the result to avoid racing with the thread we may + // wake up. + if eventFile != nil { + eventFile.FileOperations.(*eventfd.EventOperations).Signal(1) + eventFile.DecRef() + } } } // submitCallback processes a single callback. -func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Addr) error { +func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error { file := t.GetFile(cb.FD) if file == nil { // File not found. @@ -318,7 +273,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad // Was there an eventFD? Extract it. var eventFile *fs.File - if cb.Flags&_IOCB_FLAG_RESFD != 0 { + if cb.Flags&linux.IOCB_FLAG_RESFD != 0 { eventFile = t.GetFile(cb.ResFD) if eventFile == nil { // Bad FD. @@ -340,7 +295,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad // Check offset for reads/writes. switch cb.OpCode { - case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV, _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV: + case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV: if cb.Offset < 0 { return syserror.EINVAL } @@ -366,7 +321,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad // Perform the request asynchronously. file.IncRef() - fs.Async(func() { performCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile) }) + t.QueueAIO(getAIOCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile)) // All set. return nil @@ -395,7 +350,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } // Copy in this callback. - var cb ioCallback + var cb linux.IOCallback cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) if _, err := t.CopyIn(cbAddr, &cb); err != nil { @@ -424,10 +379,4 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc return uintptr(nrEvents), nil, nil } -// IoCancel implements linux syscall io_cancel(2). -// -// It is not presently supported (ENOSYS indicates no support on this -// architecture). 
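// The refactor above swaps fs.Async for a kernel.AIOCallback queued
// with t.QueueAIO, bracketed by the AIOContext request lifecycle:
// Prepare reserves a slot before queueing, and the callback must end
// in FinishRequest, or CancelPendingRequest if the context has died.
// Condensed sketch, assuming the mm and syserror imports above
// (submitAsync and work are illustrative):
func submitAsync(t *kernel.Task, aioCtx *mm.AIOContext, work kernel.AIOCallback) error {
	if ready := aioCtx.Prepare(); !ready {
		return syserror.EAGAIN // Context is busy.
	}
	// work must call aioCtx.FinishRequest(ev) when done, or
	// aioCtx.CancelPendingRequest() if it finds aioCtx.Dead().
	t.QueueAIO(work)
	return nil
}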
-func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - return 0, nil, syserror.ENOSYS -} +// LINT.ThenChange(vfs2/aio.go) diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 8347617bd..2797c6a72 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -900,14 +900,20 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 { // // If who is positive, it represents a PID. If negative, it represents a PGID. // If the PID or PGID is invalid, the owner is silently unset. -func fSetOwn(t *kernel.Task, file *fs.File, who int32) { +func fSetOwn(t *kernel.Task, file *fs.File, who int32) error { a := file.Async(fasync.New).(*fasync.FileAsync) if who < 0 { + // Check for overflow before flipping the sign. + if who-1 > who { + return syserror.EINVAL + } pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who)) a.SetOwnerProcessGroup(t, pg) + } else { + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who)) + a.SetOwnerThreadGroup(t, tg) } - tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who)) - a.SetOwnerThreadGroup(t, tg) + return nil } // Fcntl implements linux syscall fcntl(2). @@ -935,10 +941,10 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - t.FDTable().SetFlags(fd, kernel.FDFlags{ + err := t.FDTable().SetFlags(fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) - return 0, nil, nil + return 0, nil, err case linux.F_GETFL: return uintptr(file.Flags().ToLinux()), nil, nil case linux.F_SETFL: @@ -1042,8 +1048,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil case linux.F_SETOWN: - fSetOwn(t, file, args[2].Int()) - return 0, nil, nil + return 0, nil, fSetOwn(t, file, args[2].Int()) case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) @@ -1111,17 +1116,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } -// LINT.ThenChange(vfs2/fd.go) - -const ( - _FADV_NORMAL = 0 - _FADV_RANDOM = 1 - _FADV_SEQUENTIAL = 2 - _FADV_WILLNEED = 3 - _FADV_DONTNEED = 4 - _FADV_NOREUSE = 5 -) - // Fadvise64 implements linux syscall fadvise64(2). // This implementation currently ignores the provided advice. 
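// The overflow check added to fSetOwn above guards the one negative
// int32 whose negation does not fit: math.MinInt32. For that value
// alone, who-1 wraps around to math.MaxInt32, so the comparison is
// true. A standalone illustration (negateWho is hypothetical):
func negateWho(who int32) (int32, error) {
	if who >= 0 {
		return who, nil
	}
	if who-1 > who { // Only true when who == math.MinInt32.
		return 0, syserror.EINVAL
	}
	return -who, nil
}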
func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { @@ -1146,12 +1140,12 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } switch advice { - case _FADV_NORMAL: - case _FADV_RANDOM: - case _FADV_SEQUENTIAL: - case _FADV_WILLNEED: - case _FADV_DONTNEED: - case _FADV_NOREUSE: + case linux.POSIX_FADV_NORMAL: + case linux.POSIX_FADV_RANDOM: + case linux.POSIX_FADV_SEQUENTIAL: + case linux.POSIX_FADV_WILLNEED: + case linux.POSIX_FADV_DONTNEED: + case linux.POSIX_FADV_NOREUSE: default: return 0, nil, syserror.EINVAL } @@ -1160,8 +1154,6 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } -// LINT.IfChange - func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b68261f72..f04d78856 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall switch cmd { case linux.FUTEX_WAIT: // WAIT uses a relative timeout. - mask = ^uint32(0) + mask = linux.FUTEX_BITSET_MATCH_ANY var timeoutDur time.Duration if !forever { timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond @@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.ENOSYS } } + +// SetRobustList implements linux syscall set_robust_list(2). +func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + head := args[0].Pointer() + length := args[1].SizeT() + + if length != uint(linux.SizeOfRobustListHead) { + return 0, nil, syserror.EINVAL + } + t.SetRobustList(head) + return 0, nil, nil +} + +// GetRobustList implements linux syscall get_robust_list(2). +func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + tid := args[0].Int() + head := args[1].Pointer() + size := args[2].Pointer() + + if tid < 0 { + return 0, nil, syserror.EINVAL + } + + ot := t + if tid != 0 { + if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil { + return 0, nil, syserror.ESRCH + } + } + + // Copy out head pointer. + if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil { + return 0, nil, err + } + + // Copy out size, which is a constant. 
+ if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil { + return 0, nil, err + } + + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..414fce8e3 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -29,6 +29,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 9f93f4354..64696b438 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "vfs2", srcs = [ + "aio.go", "epoll.go", "eventfd.go", "execve.go", @@ -40,6 +41,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/bits", + "//pkg/context", "//pkg/fspath", "//pkg/gohacks", "//pkg/sentry/arch", @@ -52,11 +54,13 @@ go_library( "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/fasync", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/loader", "//pkg/sentry/memmap", + "//pkg/sentry/mm", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix/transport", @@ -68,5 +72,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go new file mode 100644 index 000000000..e5cdefc50 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -0,0 +1,216 @@ +// Copyright 2018 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/mm" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// IoSubmit implements linux syscall io_submit(2). +func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := args[0].Uint64() + nrEvents := args[1].Int() + addr := args[2].Pointer() + + if nrEvents < 0 { + return 0, nil, syserror.EINVAL + } + + for i := int32(0); i < nrEvents; i++ { + // Copy in the address. + cbAddrNative := t.Arch().Native(0) + if _, err := t.CopyIn(addr, cbAddrNative); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err + } + + // Copy in this callback. + var cb linux.IOCallback + cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) + if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if i > 0 { + // Some have been successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err + } + + // Process this callback. + if err := submitCallback(t, id, &cb, cbAddr); err != nil { + if i > 0 { + // Partial success. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err + } + + // Advance to the next one. + addr += usermem.Addr(t.Arch().Width()) + } + + return uintptr(nrEvents), nil, nil +} + +// submitCallback processes a single callback. +func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error { + if cb.Reserved2 != 0 { + return syserror.EINVAL + } + + fd := t.GetFileVFS2(cb.FD) + if fd == nil { + return syserror.EBADF + } + defer fd.DecRef() + + // Was there an eventFD? Extract it. + var eventFD *vfs.FileDescription + if cb.Flags&linux.IOCB_FLAG_RESFD != 0 { + eventFD = t.GetFileVFS2(cb.ResFD) + if eventFD == nil { + return syserror.EBADF + } + defer eventFD.DecRef() + + // Check that it is an eventfd. + if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok { + return syserror.EINVAL + } + } + + ioseq, err := memoryFor(t, cb) + if err != nil { + return err + } + + // Check offset for reads/writes. + switch cb.OpCode { + case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV: + if cb.Offset < 0 { + return syserror.EINVAL + } + } + + // Prepare the request. + aioCtx, ok := t.MemoryManager().LookupAIOContext(t, id) + if !ok { + return syserror.EINVAL + } + if ready := aioCtx.Prepare(); !ready { + // Context is busy. + return syserror.EAGAIN + } + + if eventFD != nil { + // The request is set. Make sure there's a ref on the file. + // + // This is necessary when the callback executes on completion, + // which is also what will release this reference. 
+ eventFD.IncRef() + } + + // Perform the request asynchronously. + fd.IncRef() + t.QueueAIO(getAIOCallback(t, fd, eventFD, cbAddr, cb, ioseq, aioCtx)) + return nil +} + +func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback { + return func(ctx context.Context) { + if aioCtx.Dead() { + aioCtx.CancelPendingRequest() + return + } + ev := &linux.IOEvent{ + Data: cb.Data, + Obj: uint64(cbAddr), + } + + var err error + switch cb.OpCode { + case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV: + ev.Result, err = fd.PRead(ctx, ioseq, cb.Offset, vfs.ReadOptions{}) + case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV: + ev.Result, err = fd.PWrite(ctx, ioseq, cb.Offset, vfs.WriteOptions{}) + case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC: + err = fd.Sync(ctx) + } + + // Update the result. + if err != nil { + err = slinux.HandleIOErrorVFS2(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", fd) + ev.Result = -int64(kernel.ExtractErrno(err, 0)) + } + + fd.DecRef() + + // Queue the result for delivery. + aioCtx.FinishRequest(ev) + + // Notify the event file if one was specified. This needs to happen + // *after* queueing the result to avoid racing with the thread we may + // wake up. + if eventFD != nil { + eventFD.Impl().(*eventfd.EventFileDescription).Signal(1) + eventFD.DecRef() + } + } +} + +// memoryFor returns appropriate memory for the given callback. +func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) { + bytes := int(cb.Bytes) + if bytes < 0 { + // Linux also requires that this field fit in ssize_t. + return usermem.IOSequence{}, syserror.EINVAL + } + + // Since this I/O will be asynchronous with respect to t's task goroutine, + // we have no guarantee that t's AddressSpace will be active during the + // I/O. + switch cb.OpCode { + case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE: + return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ + AddressSpaceActive: false, + }) + + case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV: + return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{ + AddressSpaceActive: false, + }) + + case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP: + return usermem.IOSequence{}, nil + + default: + // Not a supported command. 
+ return usermem.IOSequence{}, syserror.EINVAL + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 6006758a5..517394ba9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -17,10 +17,13 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/fasync" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" ) @@ -134,10 +137,10 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - t.FDTable().SetFlags(fd, kernel.FDFlags{ + err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) - return 0, nil, nil + return 0, nil, err case linux.F_GETFL: return uintptr(file.StatusFlags()), nil, nil case linux.F_SETFL: @@ -152,6 +155,41 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } return uintptr(n), nil, nil + case linux.F_GETOWN: + owner, hasOwner := getAsyncOwner(t, file) + if !hasOwner { + return 0, nil, nil + } + if owner.Type == linux.F_OWNER_PGRP { + return uintptr(-owner.PID), nil, nil + } + return uintptr(owner.PID), nil, nil + case linux.F_SETOWN: + who := args[2].Int() + ownerType := int32(linux.F_OWNER_PID) + if who < 0 { + // Check for overflow before flipping the sign. + if who-1 > who { + return 0, nil, syserror.EINVAL + } + ownerType = linux.F_OWNER_PGRP + who = -who + } + return 0, nil, setAsyncOwner(t, file, ownerType, who) + case linux.F_GETOWN_EX: + owner, hasOwner := getAsyncOwner(t, file) + if !hasOwner { + return 0, nil, nil + } + _, err := t.CopyOut(args[2].Pointer(), &owner) + return 0, nil, err + case linux.F_SETOWN_EX: + var owner linux.FOwnerEx + n, err := t.CopyIn(args[2].Pointer(), &owner) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, setAsyncOwner(t, file, owner.Type, owner.PID) case linux.F_GETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -167,8 +205,151 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } err := tmpfs.AddSeals(file, args[2].Uint()) return 0, nil, err + case linux.F_SETLK, linux.F_SETLKW: + return 0, nil, posixLock(t, args, file, cmd) default: // TODO(gvisor.dev/issue/2920): Everything else is not yet supported. 
return 0, nil, syserror.EINVAL } } + +func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwnerEx, hasOwner bool) { + a := fd.AsyncHandler() + if a == nil { + return linux.FOwnerEx{}, false + } + + ot, otg, opg := a.(*fasync.FileAsync).Owner() + switch { + case ot != nil: + return linux.FOwnerEx{ + Type: linux.F_OWNER_TID, + PID: int32(t.PIDNamespace().IDOfTask(ot)), + }, true + case otg != nil: + return linux.FOwnerEx{ + Type: linux.F_OWNER_PID, + PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)), + }, true + case opg != nil: + return linux.FOwnerEx{ + Type: linux.F_OWNER_PGRP, + PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)), + }, true + default: + return linux.FOwnerEx{}, true + } +} + +func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error { + switch ownerType { + case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: + // Acceptable type. + default: + return syserror.EINVAL + } + + a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync) + if pid == 0 { + a.ClearOwner() + return nil + } + + switch ownerType { + case linux.F_OWNER_TID: + task := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) + if task == nil { + return syserror.ESRCH + } + a.SetOwnerTask(t, task) + return nil + case linux.F_OWNER_PID: + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(pid)) + if tg == nil { + return syserror.ESRCH + } + a.SetOwnerThreadGroup(t, tg) + return nil + case linux.F_OWNER_PGRP: + pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(pid)) + if pg == nil { + return syserror.ESRCH + } + a.SetOwnerProcessGroup(t, pg) + return nil + default: + return syserror.EINVAL + } +} + +func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error { + // Copy in the lock request. + flockAddr := args[2].Pointer() + var flock linux.Flock + if _, err := t.CopyIn(flockAddr, &flock); err != nil { + return err + } + + var blocker lock.Blocker + if cmd == linux.F_SETLKW { + blocker = t + } + + switch flock.Type { + case linux.F_RDLCK: + if !file.IsReadable() { + return syserror.EBADF + } + return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + + case linux.F_WRLCK: + if !file.IsWritable() { + return syserror.EBADF + } + return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker) + + case linux.F_UNLCK: + return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence) + + default: + return syserror.EINVAL + } +} + +// Fadvise64 implements fadvise64(2). +// This implementation currently ignores the provided advice. +func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + length := args[2].Int64() + advice := args[3].Int() + + // Note: offset is allowed to be negative. + if length < 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // If the FD refers to a pipe or FIFO, return error. + if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe { + return 0, nil, syserror.ESPIPE + } + + switch advice { + case linux.POSIX_FADV_NORMAL: + case linux.POSIX_FADV_RANDOM: + case linux.POSIX_FADV_SEQUENTIAL: + case linux.POSIX_FADV_WILLNEED: + case linux.POSIX_FADV_DONTNEED: + case linux.POSIX_FADV_NOREUSE: + default: + return 0, nil, syserror.EINVAL + } + + // Sure, whatever. 
+ return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index 46d3e189c..6b14c2bef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -106,7 +107,7 @@ func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall addr := args[0].Pointer() mode := args[1].ModeT() dev := args[2].Uint() - return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev) + return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev) } // Mknodat implements Linux syscall mknodat(2). @@ -115,10 +116,10 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca addr := args[1].Pointer() mode := args[2].ModeT() dev := args[3].Uint() - return 0, nil, mknodat(t, dirfd, addr, mode, dev) + return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev) } -func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error { +func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error { path, err := copyInPath(t, addr) if err != nil { return err @@ -128,9 +129,14 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint return err } defer tpop.Release() + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + if mode.FileType() == 0 { + mode |= linux.ModeRegular + } major, minor := linux.DecodeDeviceID(dev) return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{ - Mode: linux.FileMode(mode &^ t.FSContext().Umask()), + Mode: mode &^ linux.FileMode(t.FSContext().Umask()), DevMajor: uint32(major), DevMinor: minor, }) @@ -239,6 +245,55 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd }) } +// Fallocate implements linux system call fallocate(2). +func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + mode := args[1].Uint64() + offset := args[2].Int64() + length := args[3].Int64() + + file := t.GetFileVFS2(fd) + + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + if !file.IsWritable() { + return 0, nil, syserror.EBADF + } + + if mode != 0 { + return 0, nil, syserror.ENOTSUP + } + + if offset < 0 || length <= 0 { + return 0, nil, syserror.EINVAL + } + + size := offset + length + + if size < 0 { + return 0, nil, syserror.EFBIG + } + + limit := limits.FromContext(t).Get(limits.FileSize).Cur + + if uint64(size) >= limit { + t.SendSignal(&arch.SignalInfo{ + Signo: int32(linux.SIGXFSZ), + Code: arch.SignalInfoUser, + }) + return 0, nil, syserror.EFBIG + } + + return 0, nil, file.Impl().Allocate(t, mode, uint64(offset), uint64(length)) + + // File length modified, generate notification. + // TODO(gvisor.dev/issue/1479): Reenable when Inotify is ported. + // file.Dirent.InotifyEvent(linux.IN_MODIFY, 0) +} + // Rmdir implements Linux syscall rmdir(2). 
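// Fallocate above enforces RLIMIT_FSIZE before allocating: a request
// whose end offset overflows int64 gets EFBIG, and one that reaches
// the task's file-size limit additionally delivers SIGXFSZ. The check
// in isolation, assuming the limits import added above
// (checkFileSizeLimit is an illustrative helper):
func checkFileSizeLimit(t *kernel.Task, offset, length int64) error {
	size := offset + length
	if size < 0 {
		return syserror.EFBIG // offset+length overflowed.
	}
	if limit := limits.FromContext(t).Get(limits.FileSize).Cur; uint64(size) >= limit {
		t.SendSignal(&arch.SignalInfo{
			Signo: int32(linux.SIGXFSZ),
			Code:  arch.SignalInfoUser,
		})
		return syserror.EFBIG
	}
	return nil
}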
func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { pathAddr := args[0].Pointer() @@ -313,6 +368,9 @@ func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpath if err != nil { return err } + if len(target) == 0 { + return syserror.ENOENT + } linkpath, err := copyInPath(t, linkpathAddr) if err != nil { return err diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go index 7d50b6a16..5d98134a5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/inotify.go +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -81,7 +81,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern // "EINVAL: The given event mask contains no valid events." // -- inotify_add_watch(2) - if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 { + if mask&linux.ALL_INOTIFY_BITS == 0 { return 0, nil, syserror.EINVAL } @@ -116,8 +116,11 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern } defer d.DecRef() - fd = ino.AddWatch(d.Dentry(), mask) - return uintptr(fd), nil, err + fd, err = ino.AddWatch(d.Dentry(), mask) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil } // InotifyRmWatch implements the inotify_rm_watch() syscall. diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 5a2418da9..fd6ab94b2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -15,6 +15,7 @@ package vfs2 import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -30,6 +31,77 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } defer file.DecRef() + // Handle ioctls that apply to all FDs. + switch args[1].Int() { + case linux.FIONCLEX: + t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + CloseOnExec: false, + }) + return 0, nil, nil + + case linux.FIOCLEX: + t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + CloseOnExec: true, + }) + return 0, nil, nil + + case linux.FIONBIO: + var set int32 + if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + return 0, nil, err + } + flags := file.StatusFlags() + if set != 0 { + flags |= linux.O_NONBLOCK + } else { + flags &^= linux.O_NONBLOCK + } + return 0, nil, file.SetStatusFlags(t, t.Credentials(), flags) + + case linux.FIOASYNC: + var set int32 + if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + return 0, nil, err + } + flags := file.StatusFlags() + if set != 0 { + flags |= linux.O_ASYNC + } else { + flags &^= linux.O_ASYNC + } + file.SetStatusFlags(t, t.Credentials(), flags) + return 0, nil, nil + + case linux.FIOGETOWN, linux.SIOCGPGRP: + var who int32 + owner, hasOwner := getAsyncOwner(t, file) + if hasOwner { + if owner.Type == linux.F_OWNER_PGRP { + who = -owner.PID + } else { + who = owner.PID + } + } + _, err := t.CopyOut(args[2].Pointer(), &who) + return 0, nil, err + + case linux.FIOSETOWN, linux.SIOCSPGRP: + var who int32 + if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil { + return 0, nil, err + } + ownerType := int32(linux.F_OWNER_PID) + if who < 0 { + // Check for overflow before flipping the sign. 
+ if who-1 > who { + return 0, nil, syserror.EINVAL + } + ownerType = linux.F_OWNER_PGRP + who = -who + } + return 0, nil, setAsyncOwner(t, file, ownerType, who) + } + ret, err := file.Ioctl(t, t.MemoryManager(), args) return ret, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index adeaa39cc..ea337de7c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Silently allow MS_NOSUID, since we don't implement set-id bits // anyway. - const unsupportedFlags = linux.MS_NODEV | - linux.MS_NODIRATIME | linux.MS_STRICTATIME + const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME // Linux just allows passing any flags to mount(2) - it won't fail when // unknown or unsupported flags are passed. Since we don't implement @@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if flags&linux.MS_NOEXEC == linux.MS_NOEXEC { opts.Flags.NoExec = true } + if flags&linux.MS_NODEV == linux.MS_NODEV { + opts.Flags.NoDev = true + } + if flags&linux.MS_NOSUID == linux.MS_NOSUID { + opts.Flags.NoSUID = true + } if flags&linux.MS_RDONLY == linux.MS_RDONLY { opts.ReadOnly = true } diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index 7f9debd4a..cd25597a7 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -606,3 +606,36 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall newoff, err := file.Seek(t, offset, whence) return uintptr(newoff), nil, err } + +// Readahead implements readahead(2). +func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the file is readable. + if !file.IsReadable() { + return 0, nil, syserror.EBADF + } + + // Check that the size is valid. + if int(size) < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the offset is legitimate and does not overflow. + if offset < 0 || offset+int64(size) < 0 { + return 0, nil, syserror.EINVAL + } + + // Return EINVAL; if the underlying file type does not support readahead, + // then Linux will return EINVAL to indicate as much. In the future, we + // may extend this function to actually support readahead hints. 
+ return 0, nil, syserror.EINVAL +} diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 09ecfed26..6daedd173 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -178,6 +178,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Mask: linux.STATX_SIZE, Size: uint64(length), }, + NeedWritePerm: true, }) return 0, nil, handleSetSizeError(t, err) } @@ -197,6 +198,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } defer file.DecRef() + if !file.IsWritable() { + return 0, nil, syserror.EINVAL + } + err := file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_SIZE, diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..8096a8f9c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 945a364a7..63ab11f8c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -15,12 +15,15 @@ package vfs2 import ( + "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args 
arch.SyscallArguments) (uintptr, *kernel.Syscal // Move data. var ( - n int64 - err error - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { // If both input and output are pipes, delegate to the pipe - // implementation. Otherwise, exactly one end is a pipe, which we - // ensure is consistently ordered after the non-pipe FD's locks by - // passing the pipe FD as usermem.IO to the non-pipe end. + // implementation. Otherwise, exactly one end is a pipe, which + // we ensure is consistently ordered after the non-pipe FD's + // locks by passing the pipe FD as usermem.IO to the non-pipe + // end. switch { case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) @@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } else { n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } + default: + panic("not possible") } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } - - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the splice operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. - } - if err = t.Block(inCh); err != nil { - break - } - } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + if err = dw.waitForBoth(t); err != nil { + break } } @@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Copy data. var ( - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { - n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 { - return uintptr(n), nil, nil + n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + if err = dw.waitForBoth(t); err != nil { + break + } + } + if n == 0 { + return 0, nil, err + } + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Sendfile implements linux system call sendfile(2). +func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + outFD := args[0].Int() + inFD := args[1].Int() + offsetAddr := args[2].Pointer() + count := int64(args[3].SizeT()) + + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + if !inFile.IsReadable() { + return 0, nil, syserror.EBADF + } + + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + if !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // Verify that the outFile Append flag is not set. 
+ if outFile.StatusFlags()&linux.O_APPEND != 0 { + return 0, nil, syserror.EINVAL + } + + // Verify that inFile is a regular file or block device. This is a + // requirement; the same check appears in Linux + // (fs/splice.c:splice_direct_to_actor). + if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil { + return 0, nil, err + } else if stat.Mask&linux.STATX_TYPE == 0 || + (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { + return 0, nil, syserror.EINVAL + } + + // Copy offset if it exists. + offset := int64(-1) + if offsetAddr != 0 { + if inFile.Options().DenyPRead { + return 0, nil, syserror.ESPIPE } - if err != syserror.ErrWouldBlock || nonBlock { + if _, err := t.CopyIn(offsetAddr, &offset); err != nil { return 0, nil, err } + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if offset+count < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Validate count. This must come after offset checks. + if count < 0 { + return 0, nil, syserror.EINVAL + } + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the tee operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Copy data. + var ( + n int64 + err error + ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + // Reading from the input file should never block, since it is a regular + // file or block device. We only need to check if writing to the output + // file can block. + nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 + if outIsPipe { + for n < count { + var spliceN int64 + if offset != -1 { + spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{}) + offset += spliceN + } else { + spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } - if err := t.Block(inCh); err != nil { - return 0, nil, err + n += spliceN + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) + } + if err != nil { + break } } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. + } else { + // Read inFile to buffer, then write the contents to outFile. + buf := make([]byte, count) + for n < count { + var readN int64 + if offset != -1 { + readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) + offset += readN + } else { + readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + } + if readN == 0 && err == io.EOF { + // We reached the end of the file. Eat the + // error and exit the loop. + err = nil + break } - n += readN + if err != nil { + break + } + + // Write all of the bytes that we read.
This may need + // multiple write calls to complete. + wbuf := buf[:n] + for len(wbuf) > 0 { + var writeN int64 + writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) + wbuf = wbuf[writeN:] + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForOut(t) + } + if err != nil { + // We didn't complete the write. Only + // report the bytes that were actually + // written, and rewind the offset. + notWritten := int64(len(wbuf)) + n -= notWritten + if offset != -1 { + offset -= notWritten + } + break + } + } + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) } + if err != nil { + break + } + } + } + + if offsetAddr != 0 { + // Copy out the new offset. + if _, err := t.CopyOut(offsetAddr, offset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not +// thread-safe, and does not take a reference on the vfs.FileDescriptions. +// +// Users must call destroy() when finished. +type dualWaiter struct { + inFile *vfs.FileDescription + outFile *vfs.FileDescription + + inW waiter.Entry + inCh chan struct{} + outW waiter.Entry + outCh chan struct{} +} + +// waitForBoth waits for both dw.inFile and dw.outFile to be ready. +func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { + if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if dw.inCh == nil { + dw.inW, dw.inCh = waiter.NewChannelEntry(nil) + dw.inFile.EventRegister(&dw.inW, eventMaskRead) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.inCh); err != nil { + return err + } + } + return dw.waitForOut(t) +} + +// waitForOut waits for dw.outFile to be ready. +func (dw *dualWaiter) waitForOut(t *kernel.Task) error { + if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready now. Try again before blocking. + return nil } + if err := t.Block(dw.outCh); err != nil { + return err + } + } + return nil +} + +// destroy cleans up resources held by dw. No more calls to wait* can occur +// after destroy is called. +func (dw *dualWaiter) destroy() { + if dw.inCh != nil { + dw.inFile.EventUnregister(&dw.inW) + dw.inCh = nil + } + if dw.outCh != nil { + dw.outFile.EventUnregister(&dw.outW) + dw.outCh = nil } + dw.inFile = nil + dw.outFile = nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 365250b0b..0d0ebf46a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -65,10 +65,8 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel nbytes := args[2].Int64() flags := args[3].Uint() - if offset < 0 { - return 0, nil, syserror.EINVAL - } - if nbytes < 0 { + // Check for negative values and overflow.
+ if offset < 0 || offset+nbytes < 0 { return 0, nil, syserror.EINVAL } if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 { @@ -81,7 +79,37 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel } defer file.DecRef() - // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of - // [offset, offset+nbytes). - return 0, nil, file.Sync(t) + // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support + // is a full-file sync, i.e. fsync(2). As a result, there are severe + // limitations on how much we support sync_file_range: + // - In Linux, sync_file_range(2) doesn't write out the file's metadata, even + // if the file size is changed. We do. + // - We always sync the entire file instead of [offset, offset+nbytes). + // - We do not support the use of WAIT_BEFORE without WAIT_AFTER. For + // correctness, we would have to perform a write-out every time WAIT_BEFORE + // was used, but this would be much more expensive than expected if there + // were no write-out operations in progress. + // - Whenever WAIT_AFTER is used, we sync the file. + // - Ignore WRITE. If this flag is used with WAIT_AFTER, then the file will + // be synced anyway. If this flag is used without WAIT_AFTER, then it is + // safe (and less expensive) to do nothing, because the syscall will not + // wait for the write-out to complete--we only need to make sure that the + // next time WAIT_BEFORE or WAIT_AFTER are used, the write-out completes. + // - According to fs/sync.c, WAIT_BEFORE|WAIT_AFTER "will detect any I/O + // errors or ENOSPC conditions and will return those to the caller, after + // clearing the EIO and ENOSPC flags in the address_space." We don't do + // this. 
+ + if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 && + flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 { + t.Kernel().EmitUnimplementedEvent(t) + return 0, nil, syserror.ENOSYS + } + + if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 { + if err := file.Sync(t); err != nil { + return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + } + } + return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 428a95fbc..c576d9475 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -44,7 +44,7 @@ func Override() { s.Table[23] = syscalls.Supported("select", Select) s.Table[32] = syscalls.Supported("dup", Dup) s.Table[33] = syscalls.Supported("dup2", Dup2) - delete(s.Table, 40) // sendfile + s.Table[40] = syscalls.Supported("sendfile", Sendfile) s.Table[41] = syscalls.Supported("socket", Socket) s.Table[42] = syscalls.Supported("connect", Connect) s.Table[43] = syscalls.Supported("accept", Accept) @@ -92,7 +92,7 @@ func Override() { s.Table[162] = syscalls.Supported("sync", Sync) s.Table[165] = syscalls.Supported("mount", Mount) s.Table[166] = syscalls.Supported("umount2", Umount2) - delete(s.Table, 187) // readahead + s.Table[187] = syscalls.Supported("readahead", Readahead) s.Table[188] = syscalls.Supported("setxattr", Setxattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr) @@ -105,14 +105,10 @@ func Override() { s.Table[197] = syscalls.Supported("removexattr", Removexattr) s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr) s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr) - delete(s.Table, 206) // io_setup - delete(s.Table, 207) // io_destroy - delete(s.Table, 208) // io_getevents - delete(s.Table, 209) // io_submit - delete(s.Table, 210) // io_cancel + s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) s.Table[213] = syscalls.Supported("epoll_create", EpollCreate) s.Table[217] = syscalls.Supported("getdents64", Getdents64) - delete(s.Table, 221) // fdavise64 + s.Table[221] = syscalls.PartiallySupported("fadvise64", Fadvise64, "The syscall is 'supported', but ignores all provided advice.", nil) s.Table[232] = syscalls.Supported("epoll_wait", EpollWait) s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl) s.Table[235] = syscalls.Supported("utimes", Utimes) @@ -142,7 +138,7 @@ func Override() { s.Table[282] = syscalls.Supported("signalfd", Signalfd) s.Table[283] = syscalls.Supported("timerfd_create", TimerfdCreate) s.Table[284] = syscalls.Supported("eventfd", Eventfd) - delete(s.Table, 285) // fallocate + s.Table[285] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil) s.Table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime) s.Table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime) s.Table[288] = syscalls.Supported("accept4", Accept4) diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go index 65868cb26..cd1b95117 100644 --- a/pkg/sentry/time/parameters.go +++ b/pkg/sentry/time/parameters.go @@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par // // The log level is determined by the error severity. 
func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) { - fn := log.Debugf - if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() { + magNS := int64(errorNS.Magnitude()) + if magNS <= 10*time.Microsecond.Nanoseconds() { + // Don't log small errors. + return + } + fn := log.Infof + if magNS > time.Millisecond.Nanoseconds() { + // Upgrade large errors to warning. fn = log.Warningf - } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() { - fn = log.Infof } fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 16d9f3a28..642769e7c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -44,6 +44,7 @@ go_library( "filesystem_impl_util.go", "filesystem_type.go", "inotify.go", + "lock.go", "mount.go", "mount_unsafe.go", "options.go", @@ -72,7 +73,6 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/uniqueid", - "//pkg/sentry/vfs/lock", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 66f3105bd..4b9faf2ea 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -169,8 +169,6 @@ This construction, which is essentially a type-safe analogue to Linux's - binder, which is similarly far too incomplete to use. - - whitelistfs, which we are already actively attempting to remove. - - Save/restore. For instance, it is unclear if the current implementation of the `state` package supports the inheritance pattern described above. diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index b7c6b60b8..641e3e502 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -300,12 +300,15 @@ func (d *anonDentry) DecRef() { // InotifyWithParent implements DentryImpl.InotifyWithParent. // -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *anonDentry) InotifyWithParent(events uint32, cookie uint32, et EventType) {} +// Although Linux technically supports inotify on pseudo filesystems (inotify +// is implemented at the vfs layer), it is not particularly useful. It is left +// unimplemented until someone actually needs it. +func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {} // Watches implements DentryImpl.Watches. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. func (d *anonDentry) Watches() *Watches { return nil } + +// OnZeroWatches implements Dentry.OnZeroWatches. +func (d *anonDentry) OnZeroWatches() {} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 24af13eb1..cea3e6955 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -113,12 +113,29 @@ type DentryImpl interface { // // Note that the events may not actually propagate up to the user, depending // on the event masks. - InotifyWithParent(events uint32, cookie uint32, et EventType) + InotifyWithParent(events, cookie uint32, et EventType) // Watches returns the set of inotify watches for the file corresponding to // the Dentry. Dentries that are hard links to the same underlying file // share the same watches. + // + // Watches may return nil if the dentry belongs to a FilesystemImpl that + // does not support inotify. If an implementation returns a non-nil watch + // set, it must always return a non-nil watch set. Likewise, if an + // implementation returns a nil watch set, it must always return a nil watch + // set. 
+ // + // The caller does not need to hold a reference on the dentry. Watches() *Watches + + // OnZeroWatches is called whenever the number of watches on a dentry drops + // to zero. This is needed by some FilesystemImpls (e.g. gofer) to manage + // dentry lifetime. + // + // The caller does not need to hold a reference on the dentry. OnZeroWatches + // may acquire inotify locks, so to prevent deadlock, no inotify locks should + // be held by the caller. + OnZeroWatches() } // IncRef increments d's reference count. @@ -149,17 +166,26 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } -// InotifyWithParent notifies all watches on the inodes for this dentry and +// InotifyWithParent notifies all watches on the targets represented by d and // its parent of events. -func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et EventType) { +func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) { d.impl.InotifyWithParent(events, cookie, et) } // Watches returns the set of inotify watches associated with d. +// +// Watches will return nil if d belongs to a FilesystemImpl that does not +// support inotify. func (d *Dentry) Watches() *Watches { return d.impl.Watches() } +// OnZeroWatches performs cleanup tasks whenever the number of watches on a +// dentry drops to zero. +func (d *Dentry) OnZeroWatches() { + d.impl.OnZeroWatches() +} + // The following functions are exported so that filesystem implementations can // use them. The vfs package, and users of VFS, should not call these // functions. diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 13c48824e..0c42574db 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -42,11 +42,20 @@ type FileDescription struct { // operations. refs int64 + // flagsMu protects statusFlags and asyncHandler below. + flagsMu sync.Mutex + // statusFlags contains status flags, "initialized by open(2) and possibly - // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic - // memory operations. + // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic + // memory operations when it does not need to be synchronized with an + // access to asyncHandler. statusFlags uint32 + // asyncHandler handles O_ASYNC signal generation. It is set with the + // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must + // also be set by fcntl(2). + asyncHandler FileAsync + // epolls is the set of epollInterests registered for this FileDescription. // epolls is protected by epollMu. epollMu sync.Mutex @@ -82,8 +91,7 @@ type FileDescription struct { // FileDescriptionOptions contains options to FileDescription.Init(). type FileDescriptionOptions struct { - // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is - // usually only the case if O_DIRECT would actually have an effect. + // If AllowDirectIO is true, allow O_DIRECT to be set on the file. AllowDirectIO bool // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE. @@ -193,6 +201,13 @@ func (fd *FileDescription) DecRef() { fd.vd.mount.EndWrite() } fd.vd.DecRef() + fd.flagsMu.Lock() + // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. 
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Unregister(fd) + } + fd.asyncHandler = nil + fd.flagsMu.Unlock() } else if refs < 0 { panic("FileDescription.DecRef() called without holding a reference") } @@ -276,7 +291,18 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede } // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK - atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) + fd.flagsMu.Lock() + if fd.asyncHandler != nil { + // Use fd.statusFlags instead of oldFlags, which may have become outdated, + // to avoid double registering/unregistering. + if fd.statusFlags&linux.O_ASYNC == 0 && flags&linux.O_ASYNC != 0 { + fd.asyncHandler.Register(fd) + } else if fd.statusFlags&linux.O_ASYNC != 0 && flags&linux.O_ASYNC == 0 { + fd.asyncHandler.Unregister(fd) + } + } + fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags) + fd.flagsMu.Unlock() return nil } @@ -328,6 +354,10 @@ type FileDescriptionImpl interface { // represented by the FileDescription. StatFS(ctx context.Context) (linux.Statfs, error) + // Allocate grows the file represented by FileDescription to offset + length bytes. + // Only mode == 0 is supported currently. + Allocate(ctx context.Context, mode, offset, length uint64) error + // waiter.Waitable methods may be used to poll for I/O events. waiter.Waitable @@ -438,14 +468,10 @@ type FileDescriptionImpl interface { UnlockBSD(ctx context.Context, uid lock.UniqueID) error // LockPOSIX tries to acquire a POSIX-style advisory file lock. - // - // TODO(gvisor.dev/issue/1480): POSIX-style file locking - LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error + LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error // UnlockPOSIX releases a POSIX-style advisory file lock. - // - // TODO(gvisor.dev/issue/1480): POSIX-style file locking - UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error + UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error } // Dirent holds the information contained in struct linux_dirent64. @@ -537,17 +563,23 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.impl.StatFS(ctx) } -// Readiness returns fd's I/O readiness. +// Readiness implements waiter.Waitable.Readiness. +// +// It returns fd's I/O readiness. func (fd *FileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { return fd.impl.Readiness(mask) } -// EventRegister registers e for I/O readiness events in mask. +// EventRegister implements waiter.Waitable.EventRegister. +// +// It registers e for I/O readiness events in mask. func (fd *FileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { fd.impl.EventRegister(e, mask) } -// EventUnregister unregisters e for I/O readiness events. +// EventUnregister implements waiter.Waitable.EventUnregister. +// +// It unregisters e for I/O readiness events.
func (fd *FileDescription) EventUnregister(e *waiter.Entry) { fd.impl.EventUnregister(e) } @@ -764,3 +796,42 @@ func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, func (fd *FileDescription) UnlockBSD(ctx context.Context) error { return fd.impl.UnlockBSD(ctx, fd) } + +// LockPOSIX locks a POSIX-style file range lock. +func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error { + return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block) +} + +// UnlockPOSIX unlocks a POSIX-style file range lock. +func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error { + return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence) +} + +// A FileAsync sends signals to its owner when w is ready for IO. This is only +// implemented by pkg/sentry/fasync:FileAsync, but we unfortunately need this +// interface to avoid circular dependencies. +type FileAsync interface { + Register(w waiter.Waitable) + Unregister(w waiter.Waitable) +} + +// AsyncHandler returns the FileAsync for fd. +func (fd *FileDescription) AsyncHandler() FileAsync { + fd.flagsMu.Lock() + defer fd.flagsMu.Unlock() + return fd.asyncHandler +} + +// SetAsyncHandler sets fd.asyncHandler if it has not been set before and +// returns it. +func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsync { + fd.flagsMu.Lock() + defer fd.flagsMu.Unlock() + if fd.asyncHandler == nil { + fd.asyncHandler = newHandler() + if fd.statusFlags&linux.O_ASYNC != 0 { + fd.asyncHandler.Register(fd) + } + } + return fd.asyncHandler +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index af7213dfd..6b8b4ad49 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/vfs/lock" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -57,6 +56,12 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err return linux.Statfs{}, syserror.ENOSYS } +// Allocate implements FileDescriptionImpl.Allocate analogously to +// fallocate called on a regular file, directory, or FIFO in Linux. +func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + // Readiness implements waiter.Waitable.Readiness analogously to // file_operations::poll == NULL in Linux. func (FileDescriptionDefaultImpl) Readiness(mask waiter.EventMask) waiter.EventMask { @@ -159,6 +164,11 @@ func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) // implementations of non-directory I/O methods that return EISDIR. type DirectoryFileDescriptionDefaultImpl struct{} +// Allocate implements FileDescriptionImpl.Allocate. +func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EISDIR +} + // PRead implements FileDescriptionImpl.PRead.
func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.EISDIR @@ -328,7 +338,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src writable, ok := fd.data.(WritableDynamicBytesSource) if !ok { - return 0, syserror.EINVAL + return 0, syserror.EIO } n, err := writable.Write(ctx, src, offset) if err != nil { @@ -369,14 +379,19 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M // LockFD may be used by most implementations of FileDescriptionImpl.Lock* // functions. Caller must call Init(). type LockFD struct { - locks *lock.FileLocks + locks *FileLocks } // Init initializes fd with FileLocks to use. -func (fd *LockFD) Init(locks *lock.FileLocks) { +func (fd *LockFD) Init(locks *FileLocks) { fd.locks = locks } +// Locks returns the locks associated with this file. +func (fd *LockFD) Locks() *FileLocks { + return fd.locks +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error { return fd.locks.LockBSD(uid, t, block) @@ -388,17 +403,6 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { return nil } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { - return fd.locks.LockPOSIX(uid, t, rng, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error { - fd.locks.UnlockPOSIX(uid, rng) - return nil -} - // NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface // returning ENOLCK. type NoLockFD struct{} @@ -414,11 +418,11 @@ func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { +func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return syserror.ENOLCK } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, rng fslock.LockRange) error { +func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return syserror.ENOLCK } diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 5061f6ac9..3b7e1c273 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -155,11 +155,11 @@ func TestGenCountFD(t *testing.T) { } // Write and PWrite fails. 
- if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EINVAL { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL) + if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO { + t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } - if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EINVAL { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL) + if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO { + t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 1edd584c9..6bb9ca180 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -524,8 +524,6 @@ type FilesystemImpl interface { // // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error - - // TODO(gvisor.dev/issue/1479): inotify_add_watch() } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md new file mode 100644 index 000000000..e7da49faa --- /dev/null +++ b/pkg/sentry/vfs/g3doc/inotify.md @@ -0,0 +1,210 @@ +# Inotify + +Inotify is a mechanism for monitoring filesystem events in Linux--see +inotify(7). An inotify instance can be used to monitor files and directories for +modifications, creation/deletion, etc. The inotify API consists of system calls +that create inotify instances (inotify_init/inotify_init1) and add/remove +watches on files to an instance (inotify_add_watch/inotify_rm_watch). Events are +generated from various places in the sentry, including the syscall layer, the +vfs layer, the process fd table, and within each filesystem implementation. This +document outlines the implementation details of inotify in VFS2. + +## Inotify Objects + +Inotify data structures are implemented in the vfs package. + +### vfs.Inotify + +Inotify instances are represented by vfs.Inotify objects, which implement +vfs.FileDescriptionImpl. As in Linux, inotify fds are backed by a +pseudo-filesystem (anonfs). Each inotify instance receives events from a set of +vfs.Watch objects, which can be modified with inotify_add_watch(2) and +inotify_rm_watch(2). An application can retrieve events by reading the inotify +fd. + +### vfs.Watches + +The set of all watches held on a single file (i.e., the watch target) is stored +in vfs.Watches. Each watch will belong to a different inotify instance (an +instance can only have one watch on any watch target). The watches are stored in +a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions +to a single file will all share the same vfs.Watches. Activity on the target +causes its vfs.Watches to generate notifications on its watches’ inotify +instances. + +### vfs.Watch + +A single watch, owned by one inotify instance and applied to one watch target. +Both the vfs.Inotify owner and vfs.Watches on the target will hold a vfs.Watch, +which leads to some complicated locking behavior (see Lock Ordering). Whenever a +watch is notified of an event on its target, it will queue events to its inotify +instance for delivery to the user. + +### vfs.Event + +vfs.Event is a simple struct encapsulating all the fields for an inotify event. +It is generated by vfs.Watches and forwarded to the watches' owners. It is +serialized to the user during read(2) syscalls on the associated fs.Inotify's +fd. 
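The layout involved here is struct inotify_event from inotify(7): a fixed header of wd, mask, cookie, and len, followed by an optional NUL-padded name. As a rough, self-contained sketch of that wire format (illustrative Go only; marshalEvent and its exact padding rule are assumptions for this example, not the sentry's actual marshalling code):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// inotifyEventBaseSize is sizeof(struct inotify_event) without the trailing
// name: wd (4 bytes) + mask (4) + cookie (4) + len (4).
const inotifyEventBaseSize = 16

// marshalEvent encodes one event the way read(2) on an inotify fd presents
// it: the fixed header first, then the name padded with NULs.
func marshalEvent(wd int32, mask, cookie uint32, name string) []byte {
	var nameLen uint32
	if name != "" {
		// Pad the name (plus at least one terminating NUL) up to a
		// multiple of the header size; the exact rounding here is an
		// assumption for illustration.
		nameLen = uint32((len(name)/inotifyEventBaseSize + 1) * inotifyEventBaseSize)
	}
	buf := new(bytes.Buffer)
	binary.Write(buf, binary.LittleEndian, wd)
	binary.Write(buf, binary.LittleEndian, mask)
	binary.Write(buf, binary.LittleEndian, cookie)
	binary.Write(buf, binary.LittleEndian, nameLen)
	buf.WriteString(name)
	buf.Write(make([]byte, int(nameLen)-len(name)))
	return buf.Bytes()
}

func main() {
	b := marshalEvent(1, 0x2 /* IN_MODIFY */, 0, "bar")
	fmt.Printf("event is %d bytes: % x\n", len(b), b)
}

A reader draining the fd would peel off the 16-byte header, then skip len name bytes to find the start of the next event.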
+ +## Lock Ordering + +There are three locks related to the inotify implementation: + +* Inotify.mu: the inotify instance lock. +* Inotify.evMu: the inotify event queue lock. +* Watches.mu: the watch set lock, used to protect the collection of watches on a target. + +The correct lock ordering for inotify code is: + +Inotify.mu -> Watches.mu -> Inotify.evMu. + +Note that we use a distinct lock to protect the inotify event queue. If we +simply used Inotify.mu, we could simultaneously have locks being acquired in the +order of Inotify.mu -> Watches.mu and Watches.mu -> Inotify.mu, which would +cause deadlocks. For instance, adding a watch to an inotify instance would +require locking Inotify.mu, and then adding the same watch to the target would +cause Watches.mu to be held. At the same time, generating an event on the target +would require Watches.mu to be held before iterating through each watch, and +then notifying the owner of each watch would cause Inotify.mu to be held. + +See the vfs package comment to understand how inotify locks fit into the overall +ordering of filesystem locks. + +## Watch Targets in Different Filesystem Implementations + +In Linux, watches reside on inodes at the virtual filesystem layer. As a result, +all hard links and file descriptions on a single file share the same +watch set. In VFS2, there is no common inode structure across filesystem types +(some may not even have inodes), so we have to plumb inotify support through +each specific filesystem implementation. Some of the technical considerations +are outlined below. + +### Tmpfs + +For filesystems with inodes, like tmpfs, the design is quite similar to that of +Linux, where watches reside on the inode. + +### Pseudo-filesystems + +Technically, because inotify is implemented at the vfs layer in Linux, +pseudo-filesystems on top of kernfs support inotify passively. However, watches +can only track explicit filesystem operations like read/write, open/close, +mknod, etc., so watches on a target like /proc/self/fd will not generate events +every time a new fd is added or removed. As of this writing, we leave inotify +unimplemented in kernfs and anonfs; it does not seem particularly useful. + +### Gofer Filesystem (fsimpl/gofer) + +The gofer filesystem has several traits that make it difficult to support +inotify: + +* **There are no inodes.** A file is represented as a dentry that holds an + unopened p9 file (and possibly an open FID), through which the Sentry + interacts with the gofer. + * *Solution:* Because there is no inode structure stored in the sandbox, + inotify watches must be held on the dentry. This would be an issue in + the presence of hard links, where multiple dentries would need to share + the same set of watches, but in VFS2, we do not support the internal + creation of hard links on gofer fs. As a result, we make the assumption + that every dentry corresponds to a unique inode. However, the next point + raises an issue with this assumption: +* **The Sentry cannot always be aware of hard links on the remote + filesystem.** There is no way for us to confirm whether two files on the + remote filesystem are actually links to the same inode. QIDs and inodes are + not always 1:1. The assumption that dentries and inodes are 1:1 is + inevitably broken if there are remote hard links that we cannot detect. + * *Solution:* this is an issue with gofer fs in general, not only inotify, + and we will have to live with it.
+* **Dentries can be cached, and then evicted.** Dentry lifetime does not + correspond to file lifetime. Because gofer fs is not entirely in-memory, the + absence of a dentry does not mean that the corresponding file does not + exist, nor does a dentry reaching zero references mean that the + corresponding file no longer exists. When a dentry reaches zero references, + it will be cached, in case the file at that path is needed again in the + future. However, the dentry may be evicted from the cache, which will cause + a new dentry to be created next time the same file path is used. The + existing watches will be lost. + * *Solution:* When a dentry reaches zero references, do not cache it if it + has any watches, so we can avoid eviction/destruction. Note that if the + dentry was deleted or invalidated (d.vfsd.IsDead()), we should still + destroy it along with its watches. Additionally, when a dentry’s last + watch is removed, we cache it if it also has zero references. This way, + the dentry can eventually be evicted from memory if it is no longer + needed. +* **Dentries can be invalidated.** Another issue with dentry lifetime is that + the remote file at the path the dentry represents may change from underneath + it. In this case, the next time that the dentry is used, it will be + invalidated and a new dentry will replace it. When that happens, it is not + clear what should be done with the watches on the old dentry. + * *Solution:* Silently destroy the watches when invalidation occurs. We + have no way of knowing exactly what happened by the time we detect it. Inotify + instances on NFS files in Linux probably behave in a similar fashion, + since inotify is implemented at the vfs layer and is not aware of the + complexities of remote file systems. + * An alternative would be to issue some kind of event upon invalidation, + e.g. a delete event, but this has several issues: + * We cannot discern whether the remote file was invalidated because it was + moved, deleted, etc. This information is crucial, because these cases + should result in different events. Furthermore, the watches should only + be destroyed if the file has been deleted. + * Moreover, the mechanism for detecting whether the underlying file has + changed is to check whether a new QID is given by the gofer. This may + result in false positives, e.g. suppose that the server closed and + re-opened the same file, which may result in a new QID. + * Finally, the time of the event may be completely different from the time + of the file modification, since a dentry is not immediately notified + when the underlying file has changed. It would be quite unexpected to + receive the notification when invalidation was triggered, i.e. the next + time the file was accessed within the sandbox, because then the + read/write/etc. operation on the file would not result in the expected + event. + * Another point in favor of the first solution: inotify in Linux can + already be lossy on local filesystems (one of the sacrifices made so + that filesystem performance isn’t killed), and it is lossy on NFS for + similar reasons to gofer fs. Therefore, it is better for inotify to be + silent than to emit incorrect notifications. +* **There may be external users of the remote filesystem.** We can only track + operations performed on the file within the sandbox. This is sufficient + under InteropModeExclusive, but whenever there are external users, the set + of actions we are aware of is incomplete.
+ * *Solution:* We could either return an error or just issue a warning when + inotify is used without InteropModeExclusive. Although faulty, VFS1 + allows it when the filesystem is shared, and Linux does the same for + remote filesystems (as mentioned above, inotify sits at the vfs level). + +## Dentry Interface + +For events that must be generated above the vfs layer, we provide the following +DentryImpl methods to allow interactions with targets on any FilesystemImpl: + +* **InotifyWithParent()** generates events on the dentry’s watches as well as + its parent’s. +* **Watches()** retrieves the watch set of the target represented by the + dentry. This is used to access and modify watches on a target. +* **OnZeroWatches()** performs cleanup tasks after the last watch is removed + from a dentry. This is needed by gofer fs, which must allow a watched dentry + to be cached once it has no more watches. Most implementations can just do + nothing. Note that OnZeroWatches() must be called after all inotify locks + are released to preserve lock ordering, since it may acquire + FilesystemImpl-specific locks. + +## IN_EXCL_UNLINK + +There are several options that can be set for a watch, specified as part of the +mask in inotify_add_watch(2). In particular, IN_EXCL_UNLINK requires some +additional support in each filesystem. + +A watch with IN_EXCL_UNLINK will not generate events for its target if it +corresponds to a path that was unlinked. For instance, if an fd is opened on +“foo/bar” and “foo/bar” is subsequently unlinked, any reads/writes/etc. on the +fd will be ignored by watches on “foo” or “foo/bar” with IN_EXCL_UNLINK. This +requires each DentryImpl to keep track of whether it has been unlinked, in order +to determine whether events should be sent to watches with IN_EXCL_UNLINK. + +## IN_ONESHOT + +One-shot watches expire after generating a single event. When an event occurs, +all one-shot watches on the target that successfully generated an event are +removed. Lock ordering can cause the management of one-shot watches to be quite +expensive; see Watches.Notify() for more information. diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 7fa7d2d0c..167b731ac 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -49,9 +49,6 @@ const ( // Inotify represents an inotify instance created by inotify_init(2) or // inotify_init1(2). Inotify implements FileDescriptionImpl. // -// Lock ordering: -// Inotify.mu -> Watches.mu -> Inotify.evMu -// // +stateify savable type Inotify struct { vfsfd FileDescription @@ -122,20 +119,40 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) // Release implements FileDescriptionImpl.Release. Release removes all // watches and frees all resources for an inotify instance. func (i *Inotify) Release() { + var ds []*Dentry + // We need to hold i.mu to avoid a race with concurrent calls to // Inotify.handleDeletion from Watches. There's no risk of Watches // accessing this Inotify after the destructor ends, because we remove all // references to it below. i.mu.Lock() - defer i.mu.Unlock() for _, w := range i.watches { // Remove references to the watch from the watches set on the target. We // don't need to worry about the references from i.watches, since this // file description is about to be destroyed. - w.set.Remove(i.id) + d := w.target + ws := d.Watches() + // Watchable dentries should never return a nil watch set. 
+ if ws == nil { + panic("Cannot remove watch from an unwatchable dentry") + } + ws.Remove(i.id) + if ws.Size() == 0 { + ds = append(ds, d) + } + } + i.mu.Unlock() + + for _, d := range ds { + d.OnZeroWatches() } } +// Allocate implements FileDescription.Allocate. +func (i *Inotify) Allocate(ctx context.Context, mode, offset, length uint64) error { + panic("Allocate should not be called on read-only inotify fds") +} + // EventRegister implements waiter.Waitable. func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) { i.queue.EventRegister(e, mask) @@ -162,12 +179,12 @@ func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask { return mask & ready } -// PRead implements FileDescriptionImpl. +// PRead implements FileDescriptionImpl.PRead. func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.ESPIPE } -// PWrite implements FileDescriptionImpl. +// PWrite implements FileDescriptionImpl.PWrite. func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { return 0, syserror.ESPIPE } @@ -226,7 +243,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt return writeLen, nil } -// Ioctl implements fs.FileOperations.Ioctl. +// Ioctl implements FileDescriptionImpl.Ioctl. func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch args[1].Int() { case linux.FIONREAD: @@ -272,20 +289,19 @@ func (i *Inotify) queueEvent(ev *Event) { // newWatchLocked creates and adds a new watch to target. // -// Precondition: i.mu must be locked. -func (i *Inotify) newWatchLocked(target *Dentry, mask uint32) *Watch { - targetWatches := target.Watches() +// Precondition: i.mu must be locked. ws must be the watch set for target d. +func (i *Inotify) newWatchLocked(d *Dentry, ws *Watches, mask uint32) *Watch { w := &Watch{ - owner: i, - wd: i.nextWatchIDLocked(), - set: targetWatches, - mask: mask, + owner: i, + wd: i.nextWatchIDLocked(), + target: d, + mask: mask, } // Hold the watch in this inotify instance as well as the watch set on the // target. i.watches[w.wd] = w - targetWatches.Add(w) + ws.Add(w) return w } @@ -297,22 +313,11 @@ func (i *Inotify) nextWatchIDLocked() int32 { return i.nextWatchMinusOne } -// handleDeletion handles the deletion of the target of watch w. It removes w -// from i.watches and a watch removal event is generated. -func (i *Inotify) handleDeletion(w *Watch) { - i.mu.Lock() - _, found := i.watches[w.wd] - delete(i.watches, w.wd) - i.mu.Unlock() - - if found { - i.queueEvent(newEvent(w.wd, "", linux.IN_IGNORED, 0)) - } -} - // AddWatch constructs a new inotify watch and adds it to the target. It // returns the watch descriptor returned by inotify_add_watch(2). -func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { +// +// The caller must hold a reference on target. +func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) { // Note: Locking this inotify instance protects the result returned by // Lookup() below. 
With the lock held, we know for sure the lookup result // won't become stale because it's impossible for *this* instance to @@ -320,8 +325,14 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { i.mu.Lock() defer i.mu.Unlock() + ws := target.Watches() + if ws == nil { + // While Linux supports inotify watches on all filesystem types, watches on + // filesystems like kernfs are not generally useful, so we do not. + return 0, syserror.EPERM + } // Does the target already have a watch from this inotify instance? - if existing := target.Watches().Lookup(i.id); existing != nil { + if existing := ws.Lookup(i.id); existing != nil { newmask := mask if mask&linux.IN_MASK_ADD != 0 { // "Add (OR) events to watch mask for this pathname if it already @@ -329,12 +340,12 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { newmask |= atomic.LoadUint32(&existing.mask) } atomic.StoreUint32(&existing.mask, newmask) - return existing.wd + return existing.wd, nil } // No existing watch, create a new watch. - w := i.newWatchLocked(target, mask) - return w.wd + w := i.newWatchLocked(target, ws, mask) + return w.wd, nil } // RmWatch looks up an inotify watch for the given 'wd' and configures the @@ -353,9 +364,19 @@ func (i *Inotify) RmWatch(wd int32) error { delete(i.watches, wd) // Remove the watch from the watch target. - w.set.Remove(w.OwnerID()) + ws := w.target.Watches() + // AddWatch ensures that w.target has a non-nil watch set. + if ws == nil { + panic("Watched dentry cannot have nil watch set") + } + ws.Remove(w.OwnerID()) + remaining := ws.Size() i.mu.Unlock() + if remaining == 0 { + w.target.OnZeroWatches() + } + // Generate the event for the removal. i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0)) @@ -374,6 +395,13 @@ type Watches struct { ws map[uint64]*Watch } +// Size returns the number of watches held by w. +func (w *Watches) Size() int { + w.mu.Lock() + defer w.mu.Unlock() + return len(w.ws) +} + // Lookup returns the watch owned by an inotify instance with the given id. // Returns nil if no such watch exists. // @@ -424,64 +452,86 @@ func (w *Watches) Remove(id uint64) { return } - if _, ok := w.ws[id]; !ok { - // While there's technically no problem with silently ignoring a missing - // watch, this is almost certainly a bug. - panic(fmt.Sprintf("Attempt to remove a watch, but no watch found with provided id %+v.", id)) + // It is possible for w.Remove() to be called for the same watch multiple + // times. See the treatment of one-shot watches in Watches.Notify(). + if _, ok := w.ws[id]; ok { + delete(w.ws, id) } - delete(w.ws, id) } -// Notify queues a new event with all watches in this set. -func (w *Watches) Notify(name string, events, cookie uint32, et EventType) { - w.NotifyWithExclusions(name, events, cookie, et, false) +// Notify queues a new event with watches in this set. Watches with +// IN_EXCL_UNLINK are skipped if the event is coming from a child that has been +// unlinked. +func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) { + var hasExpired bool + w.mu.RLock() + for _, watch := range w.ws { + if unlinked && watch.ExcludeUnlinked() && et == PathEvent { + continue + } + if watch.Notify(name, events, cookie) { + hasExpired = true + } + } + w.mu.RUnlock() + + if hasExpired { + w.cleanupExpiredWatches() + } } -// NotifyWithExclusions queues a new event with watches in this set. Watches -// with IN_EXCL_UNLINK are skipped if the event is coming from a child that -// has been unlinked. 
-func (w *Watches) NotifyWithExclusions(name string, events, cookie uint32, et EventType, unlinked bool) { - // N.B. We don't defer the unlocks because Notify is in the hot path of - // all IO operations, and the defer costs too much for small IO - // operations. +// This function is relatively expensive and should only be called where there +// are expired watches. +func (w *Watches) cleanupExpiredWatches() { + // Because of lock ordering, we cannot acquire Inotify.mu for each watch + // owner while holding w.mu. As a result, store expired watches locally + // before removing. + var toRemove []*Watch w.mu.RLock() for _, watch := range w.ws { - if unlinked && watch.ExcludeUnlinkedChildren() && et == PathEvent { - continue + if atomic.LoadInt32(&watch.expired) == 1 { + toRemove = append(toRemove, watch) } - watch.Notify(name, events, cookie) } w.mu.RUnlock() + for _, watch := range toRemove { + watch.owner.RmWatch(watch.wd) + } } -// HandleDeletion is called when the watch target is destroyed to emit -// the appropriate events. +// HandleDeletion is called when the watch target is destroyed. Clear the +// watch set, detach watches from the inotify instances they belong to, and +// generate the appropriate events. func (w *Watches) HandleDeletion() { - w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent) + w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */) - // TODO(gvisor.dev/issue/1479): This doesn't work because maps are not copied - // by value. Ideally, we wouldn't have this circular locking so we can just - // notify of IN_DELETE_SELF in the same loop below. - // - // We can't hold w.mu while calling watch.handleDeletion to preserve lock - // ordering w.r.t to the owner inotify instances. Instead, atomically move - // the watches map into a local variable so we can iterate over it safely. - // - // Because of this however, it is possible for the watches' owners to reach - // this inode while the inode has no refs. This is still safe because the - // owners can only reach the inode until this function finishes calling - // watch.handleDeletion below and the inode is guaranteed to exist in the - // meantime. But we still have to be very careful not to rely on inode state - // that may have been already destroyed. + // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for + // the owner of each watch being deleted. Instead, atomically store the + // watches map in a local variable and set it to nil so we can iterate over + // it with the assurance that there will be no concurrent accesses. var ws map[uint64]*Watch w.mu.Lock() ws = w.ws w.ws = nil w.mu.Unlock() + // Remove each watch from its owner's watch set, and generate a corresponding + // watch removal event. for _, watch := range ws { - // TODO(gvisor.dev/issue/1479): consider refactoring this. - watch.handleDeletion() + i := watch.owner + i.mu.Lock() + _, found := i.watches[watch.wd] + delete(i.watches, watch.wd) + + // Release mutex before notifying waiters because we don't control what + // they can do. + i.mu.Unlock() + + // If watch was not found, it was removed from the inotify instance before + // we could get to it, in which case we should not generate an event. + if found { + i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0)) + } } } @@ -490,18 +540,28 @@ func (w *Watches) HandleDeletion() { // +stateify savable type Watch struct { // Inotify instance which owns this watch. + // + // This field is immutable after creation. owner *Inotify // Descriptor for this watch. 
This is unique across an inotify instance. + // + // This field is immutable after creation. wd int32 - // set is the watch set containing this watch. It belongs to the target file - // of this watch. - set *Watches + // target is a dentry representing the watch target. Its watch set contains this watch. + // + // This field is immutable after creation. + target *Dentry // Events being monitored via this watch. Must be accessed with atomic // memory operations. mask uint32 + + // expired is set to 1 to indicate that this watch is a one-shot that has + // already sent a notification and therefore can be removed. Must be accessed + // with atomic memory operations. + expired int32 } // OwnerID returns the id of the inotify instance that owns this watch. @@ -509,23 +569,29 @@ func (w *Watch) OwnerID() uint64 { return w.owner.id } -// ExcludeUnlinkedChildren indicates whether the watched object should continue -// to be notified of events of its children after they have been unlinked, e.g. -// for an open file descriptor. +// ExcludeUnlinked indicates whether the watched object should continue to be +// notified of events originating from a path that has been unlinked. // -// TODO(gvisor.dev/issue/1479): Implement IN_EXCL_UNLINK. -// We can do this by keeping track of the set of unlinked children in Watches -// to skip notification. -func (w *Watch) ExcludeUnlinkedChildren() bool { +// For example, if "foo/bar" is opened and then unlinked, operations on the +// open fd may be ignored by watches on "foo" and "foo/bar" with IN_EXCL_UNLINK. +func (w *Watch) ExcludeUnlinked() bool { return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0 } -// Notify queues a new event on this watch. -func (w *Watch) Notify(name string, events uint32, cookie uint32) { +// Notify queues a new event on this watch. Returns true if this is a one-shot +// watch that should be deleted after this event was successfully queued. +func (w *Watch) Notify(name string, events uint32, cookie uint32) bool { + if atomic.LoadInt32(&w.expired) == 1 { + // This is a one-shot watch that is already in the process of being + // removed. This may happen if a second event reaches the watch target + // before this watch has been removed. + return false + } + mask := atomic.LoadUint32(&w.mask) if mask&events == 0 { // We weren't watching for this event. - return + return false } // Event mask should include bits matched from the watch plus all control @@ -534,11 +600,11 @@ func (w *Watch) Notify(name string, events uint32, cookie uint32) { effectiveMask := unmaskableBits | mask matchedEvents := effectiveMask & events w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie)) -} - -// handleDeletion handles the deletion of w's target. -func (w *Watch) handleDeletion() { - w.owner.handleDeletion(w) + if mask&linux.IN_ONESHOT != 0 { + atomic.StoreInt32(&w.expired, 1) + return true + } + return false } // Event represents a struct inotify_event from linux. @@ -606,7 +672,7 @@ func (e *Event) setName(name string) { func (e *Event) sizeOf() int { s := inotifyEventBaseSize + int(e.len) if s < inotifyEventBaseSize { - panic("overflow") + panic("Overflowed event size") } return s } @@ -676,11 +742,15 @@ func InotifyEventFromStatMask(mask uint32) uint32 { } // InotifyRemoveChild sends the appropriate notifications to the watch sets of -// the child being removed and its parent. +// the child being removed and its parent. Note that unlike most pairs of +// parent/child notifications, the child is notified first in this case.
func InotifyRemoveChild(self, parent *Watches, name string) { - self.Notify("", linux.IN_ATTRIB, 0, InodeEvent) - parent.Notify(name, linux.IN_DELETE, 0, InodeEvent) - // TODO(gvisor.dev/issue/1479): implement IN_EXCL_UNLINK. + if self != nil { + self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */) + } + if parent != nil { + parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */) + } } // InotifyRename sends the appropriate notifications to the watch sets of the @@ -691,8 +761,14 @@ func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, dirEv = linux.IN_ISDIR } cookie := uniqueid.InotifyCookie(ctx) - oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent) - newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent) + if oldParent != nil { + oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */) + } + if newParent != nil { + newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */) + } // Somewhat surprisingly, self move events do not have a cookie. - renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent) + if renamed != nil { + renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */) + } } diff --git a/pkg/sentry/vfs/lock/lock.go b/pkg/sentry/vfs/lock.go index 724dfe743..6c7583a81 100644 --- a/pkg/sentry/vfs/lock/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -17,9 +17,11 @@ // // The actual implementations can be found in the lock package under // sentry/fs/lock. -package lock +package vfs import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/syserror" ) @@ -56,7 +58,11 @@ func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) { } // LockPOSIX tries to acquire a POSIX-style lock on a file region. -func (fl *FileLocks) LockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error { +func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + rng, err := computeRange(ctx, fd, start, length, whence) + if err != nil { + return err + } if fl.posix.LockRegion(uid, t, rng, block) { return nil } @@ -67,6 +73,37 @@ // // This operation is always successful, even if there did not exist a lock on // the requested region held by uid in the first place. -func (fl *FileLocks) UnlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) { +func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error { + rng, err := computeRange(ctx, fd, start, length, whence) + if err != nil { + return err + } fl.posix.UnlockRegion(uid, rng) + return nil +} + +func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) { + var off int64 + switch whence { + case linux.SEEK_SET: + off = 0 + case linux.SEEK_CUR: + // Note that Linux does not hold any mutexes while retrieving the file + // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
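+ // As a result, a concurrent seek can change the offset between this read and the lock request below; like Linux, we tolerate that race.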
+ curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR) + if err != nil { + return fslock.LockRange{}, err + } + off = curOff + case linux.SEEK_END: + stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE}) + if err != nil { + return fslock.LockRange{}, err + } + off = int64(stat.Size) + default: + return fslock.LockRange{}, syserror.EINVAL + } + + return fslock.ComputeRange(int64(start), int64(length), off) } diff --git a/pkg/sentry/vfs/lock/BUILD b/pkg/sentry/vfs/lock/BUILD deleted file mode 100644 index d9ab063b7..000000000 --- a/pkg/sentry/vfs/lock/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "lock", - srcs = ["lock.go"], - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/sentry/fs/lock", - "//pkg/syserror", - ], -) diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index f223aeda8..dfc8573fd 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -79,6 +79,17 @@ type MountFlags struct { // NoATime is equivalent to MS_NOATIME and indicates that the // filesystem should not update access time in-place. NoATime bool + + // NoDev is equivalent to MS_NODEV and indicates that the + // filesystem should not allow access to devices (special files). + // TODO(gVisor.dev/issue/3186): respect this flag in non FUSE + // filesystems. + NoDev bool + + // NoSUID is equivalent to MS_NOSUID and indicates that the + // filesystem should not honor set-user-ID and set-group-ID bits or + // file capabilities when executing programs. + NoSUID bool } // MountOptions contains options to VirtualFilesystem.MountAt(). @@ -153,6 +164,12 @@ type SetStatOptions struct { // == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask // instead). Stat linux.Statx + + // NeedWritePerm indicates that write permission on the file is needed for + // this operation. This is needed for truncate(2) (note that ftruncate(2) + // does not require the same check--instead, it checks that the fd is + // writable). + NeedWritePerm bool } // BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt() diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index f9647f90e..33389c1df 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -94,6 +94,37 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linu return syserror.EACCES } +// MayLink determines whether creating a hard link to a file with the given +// mode, kuid, and kgid is permitted. +// +// This corresponds to Linux's fs/namei.c:may_linkat. +func MayLink(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { + // Source inode owner can hardlink all they like; otherwise, it must be a + // safe source. + if CanActAsOwner(creds, kuid) { + return nil + } + + // Only regular files can be hard linked. + if mode.FileType() != linux.S_IFREG { + return syserror.EPERM + } + + // Setuid files should not get pinned to the filesystem. + if mode&linux.S_ISUID != 0 { + return syserror.EPERM + } + + // Executable setgid files should not get pinned to the filesystem, but we + // don't support S_IXGRP anyway. + + // Hardlinking to unreadable or unwritable sources is dangerous. + if err := GenericCheckPermissions(creds, MayRead|MayWrite, mode, kuid, kgid); err != nil { + return syserror.EPERM + } + return nil +} + // AccessTypesForOpenFlags returns the access types required to open a file // with the given OpenOptions.Flags. 
Note that this is NOT the same thing as // the set of accesses permitted for the opened file: @@ -152,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { // CheckSetStat checks that creds has permission to change the metadata of a // file with the given permissions, UID, and GID as specified by stat, subject // to the rules of Linux's fs/attr.c:setattr_prepare(). -func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { +func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { + stat := &opts.Stat if stat.Mask&linux.STATX_SIZE != 0 { limit, err := CheckLimit(ctx, 0, int64(stat.Size)) if err != nil { @@ -184,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat return syserror.EPERM } } + if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) { + if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil { + return err + } + } if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { if !CanActAsOwner(creds, kuid) { if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || @@ -199,6 +236,20 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat return nil } +// CheckDeleteSticky checks whether the sticky bit is set on a directory with +// the given file mode, and if so, checks whether creds has permission to +// remove a file owned by childKUID from a directory with the given mode. +// CheckDeleteSticky is consistent with fs/linux.h:check_sticky(). +func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error { + if parentMode&linux.ModeSticky == 0 { + return nil + } + if CanActAsOwner(creds, childKUID) { + return nil + } + return syserror.EPERM +} + // CanActAsOwner returns true if creds can act as the owner of a file with the // given owning UID, consistent with Linux's // fs/inode.c:inode_owner_or_capable(). diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 9acca8bc7..522e27475 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -24,6 +24,9 @@ // Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry // VirtualFilesystem.filesystemsMu // EpollInstance.mu +// Inotify.mu +// Watches.mu +// Inotify.evMu // VirtualFilesystem.fsTypesMu // // Locking Dentry.mu in multiple Dentries requires holding @@ -120,6 +123,9 @@ type VirtualFilesystem struct { // Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes. 
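+// A second call to Init panics (see the check below) rather than silently discarding the existing mount and device tables.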
func (vfs *VirtualFilesystem) Init() error { + if vfs.mountpoints != nil { + panic("VFS already initialized") + } vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{}) vfs.devices = make(map[devTuple]*registeredDevice) vfs.anonBlockDevMinorNext = 1 diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD new file mode 100644 index 000000000..f08599ebd --- /dev/null +++ b/pkg/shim/runsc/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "runsc", + srcs = [ + "runsc.go", + "utils.go", + ], + visibility = ["//:sandbox"], + deps = [ + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go new file mode 100644 index 000000000..c5cf68efa --- /dev/null +++ b/pkg/shim/runsc/runsc.go @@ -0,0 +1,514 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "syscall" + "time" + + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +var Monitor runc.ProcessMonitor = runc.Monitor + +// DefaultCommand is the default command for Runsc. +const DefaultCommand = "runsc" + +// Runsc is the client to the runsc cli. +type Runsc struct { + Command string + PdeathSignal syscall.Signal + Setpgid bool + Root string + Log string + LogFormat runc.Format + Config map[string]string +} + +// List returns all containers created inside the provided runsc root directory. +func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { + data, err := cmdOutput(r.command(context, "list", "--format=json"), false) + if err != nil { + return nil, err + } + var out []*runc.Container + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return out, nil +} + +// State returns the state for the container provided by id. +func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) { + data, err := cmdOutput(r.command(context, "state", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var c runc.Container + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil +} + +type CreateOpts struct { + runc.IO + ConsoleSocket runc.ConsoleSocket + + // PidFile is a path to where a pid file should be created. + PidFile string + + // UserLog is a path to where runsc user log should be generated. 
+ UserLog string +} + +func (o *CreateOpts) args() (out []string, err error) { + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.UserLog != "" { + out = append(out, "--user-log", o.UserLog) + } + return out, nil +} + +// Create creates a new container and returns its pid if it was created successfully. +func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error { + args := []string{"create", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0]) + } + + return err +} + +// Start will start an already created container. +func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error { + cmd := r.command(context, "start", id) + if cio != nil { + cio.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if cio != nil { + if c, ok := cio.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0]) + } + + return err +} + +type waitResult struct { + ID string `json:"id"` + ExitStatus int `json:"exitStatus"` +} + +// Wait will wait for a running container, and return its exit status. +// +// TODO(random-liu): Add exec process support. +func (r *Runsc) Wait(context context.Context, id string) (int, error) { + data, err := cmdOutput(r.command(context, "wait", id), true) + if err != nil { + return 0, fmt.Errorf("%s: %s", err, data) + } + var res waitResult + if err := json.Unmarshal(data, &res); err != nil { + return 0, err + } + return res.ExitStatus, nil +} + +type ExecOpts struct { + runc.IO + PidFile string + InternalPidFile string + ConsoleSocket runc.ConsoleSocket + Detach bool +} + +func (o *ExecOpts) args() (out []string, err error) { + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.Detach { + out = append(out, "--detach") + } + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.InternalPidFile != "" { + abs, err := filepath.Abs(o.InternalPidFile) + if err != nil { + return nil, err + } + out = append(out, "--internal-pid-file", abs) + } + return out, nil +} + +// Exec executes an additional process inside the container based on a full OCI +// Process specification.
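+// The spec is encoded as JSON into a temporary file and passed to runsc via `exec --process`, as the body below shows.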
+func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error { + f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process") + if err != nil { + return err + } + defer os.Remove(f.Name()) + err = json.NewEncoder(f).Encode(spec) + f.Close() + if err != nil { + return err + } + args := []string{"exec", "--process", f.Name()} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0]) + } + return err +} + +// Run runs the create, start, delete lifecycle of the container and returns +// its exit status after it has exited. +func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) { + args := []string{"run", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return -1, err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + ec, err := Monitor.Start(cmd) + if err != nil { + return -1, err + } + return Monitor.Wait(cmd, ec) +} + +type DeleteOpts struct { + Force bool +} + +func (o *DeleteOpts) args() (out []string) { + if o.Force { + out = append(out, "--force") + } + return out +} + +// Delete deletes the container. +func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error { + args := []string{"delete"} + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id)...)) +} + +// KillOpts specifies options for killing a container and its processes. +type KillOpts struct { + All bool + Pid int +} + +func (o *KillOpts) args() (out []string) { + if o.All { + out = append(out, "--all") + } + if o.Pid != 0 { + out = append(out, "--pid", strconv.Itoa(o.Pid)) + } + return out +} + +// Kill sends the specified signal to the container. +func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error { + args := []string{ + "kill", + } + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...)) +} + +// Stats returns the stats for a container, such as CPU, memory, and I/O. +func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { + cmd := r.command(context, "events", "--stats", id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + defer func() { + rd.Close() + Monitor.Wait(cmd, ec) + }() + var e runc.Event + if err := json.NewDecoder(rd).Decode(&e); err != nil { + return nil, err + } + return e.Stats, nil +} + +// Events returns an event stream from runsc for a container with stats and OOM notifications.
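+// The returned channel is buffered; decode failures other than io.EOF are forwarded as events of Type "error", and the channel is closed once the stream ends.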
+func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) { + cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + rd.Close() + return nil, err + } + var ( + dec = json.NewDecoder(rd) + c = make(chan *runc.Event, 128) + ) + go func() { + defer func() { + close(c) + rd.Close() + Monitor.Wait(cmd, ec) + }() + for { + var e runc.Event + if err := dec.Decode(&e); err != nil { + if err == io.EOF { + return + } + e = runc.Event{ + Type: "error", + Err: err, + } + } + c <- &e + } + }() + return c, nil +} + +// Ps lists all the processes inside the container returning their pids. +func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var pids []int + if err := json.Unmarshal(data, &pids); err != nil { + return nil, err + } + return pids, nil +} + +// Top lists all the processes inside the container returning the full ps data. +func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + + topResults, err := runc.ParsePSOutput(data) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + return topResults, nil +} + +func (r *Runsc) args() []string { + var args []string + if r.Root != "" { + args = append(args, fmt.Sprintf("--root=%s", r.Root)) + } + if r.Log != "" { + args = append(args, fmt.Sprintf("--log=%s", r.Log)) + } + if r.LogFormat != "" { + args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat)) + } + for k, v := range r.Config { + args = append(args, fmt.Sprintf("--%s=%s", k, v)) + } + return args +} + +// runOrError will run the provided command. +// +// If an error is encountered and neither Stdout nor Stderr was set, the error +// will be returned in the format of <error>: <stderr>. +func (r *Runsc) runOrError(cmd *exec.Cmd) error { + if cmd.Stdout != nil || cmd.Stderr != nil { + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0]) + } + return err + } + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil +} + +func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd { + command := r.Command + if command == "" { + command = DefaultCommand + } + cmd := exec.CommandContext(context, command, append(r.args(), args...)...)
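+ // Setpgid places the child in its own process group; Pdeathsig (set below when configured) has the kernel signal the child if the shim exits first.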
+ cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: r.Setpgid, + } + if r.PdeathSignal != 0 { + cmd.SysProcAttr.Pdeathsig = r.PdeathSignal + } + + return cmd +} + +func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) { + b := getBuf() + defer putBuf(b) + + cmd.Stdout = b + if combined { + cmd.Stderr = b + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0]) + } + + return b.Bytes(), err +} diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go new file mode 100644 index 000000000..c514b3bc7 --- /dev/null +++ b/pkg/shim/runsc/utils.go @@ -0,0 +1,44 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "bytes" + "strings" + "sync" +) + +var bytesBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, +} + +func getBuf() *bytes.Buffer { + return bytesBufferPool.Get().(*bytes.Buffer) +} + +func putBuf(b *bytes.Buffer) { + b.Reset() + bytesBufferPool.Put(b) +} + +// FormatLogPath parses the runsc config and fills in %ID% in the log path. +func FormatLogPath(id string, config map[string]string) { + if path, ok := config["debug-log"]; ok { + config["debug-log"] = strings.Replace(path, "%ID%", id, -1) + } +} diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD new file mode 100644 index 000000000..4377306af --- /dev/null +++ b/pkg/shim/v1/proc/BUILD @@ -0,0 +1,36 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "proc", + srcs = [ + "deleted_state.go", + "exec.go", + "exec_state.go", + "init.go", + "init_state.go", + "io.go", + "process.go", + "types.go", + "utils.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go new file mode 100644 index 000000000..d9b970c4d --- /dev/null +++ b/pkg/shim/v1/proc/deleted_state.go @@ -0,0 +1,49 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type deletedState struct{} + +func (*deletedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a deleted process") +} + +func (*deletedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a deleted process") +} + +func (*deletedState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a deleted process: %w", errdefs.ErrNotFound) +} + +func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error { + return fmt.Errorf("cannot kill a deleted process: %w", errdefs.ErrNotFound) +} + +func (*deletedState) SetExited(status int) {} + +func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a deleted state") +} diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go new file mode 100644 index 000000000..1d1d90488 --- /dev/null +++ b/pkg/shim/v1/proc/exec.go @@ -0,0 +1,281 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package proc + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +type execProcess struct { + wg sync.WaitGroup + + execState execState + + mu sync.Mutex + id string + console console.Console + io runc.IO + status int + exited time.Time + pid int + internalPid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + path string + spec specs.Process + + parent *Init + waitBlock chan struct{} +} + +func (e *execProcess) Wait() { + <-e.waitBlock +} + +func (e *execProcess) ID() string { + return e.id +} + +func (e *execProcess) Pid() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.pid +} + +func (e *execProcess) ExitStatus() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.status +} + +func (e *execProcess) ExitedAt() time.Time { + e.mu.Lock() + defer e.mu.Unlock() + return e.exited +} + +func (e *execProcess) SetExited(status int) { + e.mu.Lock() + defer e.mu.Unlock() + + e.execState.SetExited(status) +} + +func (e *execProcess) setExited(status int) { + e.status = status + e.exited = time.Now() + e.parent.Platform.ShutdownConsole(context.Background(), e.console) + close(e.waitBlock) +} + +func (e *execProcess) Delete(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Delete(ctx) +} + +func (e *execProcess) delete(ctx context.Context) error { + e.wg.Wait() + if e.io != nil { + for _, c := range e.closers { + c.Close() + } + e.io.Close() + } + pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + // silently ignore error + os.Remove(pidfile) + internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + // silently ignore error + os.Remove(internalPidfile) + return nil +} + +func (e *execProcess) Resize(ws console.WinSize) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Resize(ws) +} + +func (e *execProcess) resize(ws console.WinSize) error { + if e.console == nil { + return nil + } + return e.console.Resize(ws) +} + +func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Kill(ctx, sig, false) +} + +func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error { + internalPid := e.internalPid + if internalPid != 0 { + if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{ + Pid: internalPid, + }); err != nil { + // If this returns error, consider the process has + // already stopped. + // + // TODO: Fix after signal handling is fixed. 
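+ // Wrapping errdefs.ErrNotFound lets callers treat the kill failure as the process already being gone.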
+ return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) + } + } + return nil +} + +func (e *execProcess) Stdin() io.Closer { + return e.stdin +} + +func (e *execProcess) Stdio() stdio.Stdio { + return e.stdio +} + +func (e *execProcess) Start(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Start(ctx) +} + +func (e *execProcess) start(ctx context.Context) (err error) { + var ( + socket *runc.Socket + pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + ) + if e.stdio.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create runc console socket: %w", err) + } + defer socket.Close() + } else if e.stdio.IsNull() { + if e.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil { + return fmt.Errorf("failed to create runc io pipes: %w", err) + } + } + opts := &runsc.ExecOpts{ + PidFile: pidfile, + InternalPidFile: internalPidfile, + IO: e.io, + Detach: true, + } + if socket != nil { + opts.ConsoleSocket = socket + } + eventCh := e.parent.Monitor.Subscribe() + defer func() { + // Unsubscribe if an error is returned. + if err != nil { + e.parent.Monitor.Unsubscribe(eventCh) + } + }() + if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil { + close(e.waitBlock) + return e.parent.runtimeError(err, "OCI runtime exec failed") + } + if e.stdio.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err) + } + e.closers = append(e.closers, sc) + e.stdin = sc + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + } else if !e.stdio.IsNull() { + if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(opts.PidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err) + } + e.pid = pid + internalPid, err := runc.ReadPidFile(opts.InternalPidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err) + } + e.internalPid = internalPid + go func() { + defer e.parent.Monitor.Unsubscribe(eventCh) + for event := range eventCh { + if event.Pid == e.pid { + ExitCh <- Exit{ + Timestamp: event.Timestamp, + ID: e.id, + Status: event.Status, + } + break + } + } + }() + return nil +} + +func (e *execProcess) Status(ctx context.Context) (string, error) { + e.mu.Lock() + defer e.mu.Unlock() + // if we don't have a pid then the exec process has just been created + if e.pid == 0 { + return "created", nil + } + // if we have a pid and it can be signaled, the process is running + // TODO(random-liu): Use `runsc kill --pid`. 
+ if err := unix.Kill(e.pid, 0); err == nil { + return "running", nil + } + // else if we have a pid but it can no longer be signaled, it has stopped + return "stopped", nil +} diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go new file mode 100644 index 000000000..4dcda8b44 --- /dev/null +++ b/pkg/shim/v1/proc/exec_state.go @@ -0,0 +1,154 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" +) + +type execState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type execCreatedState struct { + p *execProcess +} + +func (s *execCreatedState) transition(name string) error { + switch name { + case "running": + s.p.execState = &execRunningState{p: s.p} + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execCreatedState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execCreatedState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + return err + } + return s.transition("running") +} + +func (s *execCreatedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execCreatedState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execRunningState struct { + p *execProcess +} + +func (s *execRunningState) transition(name string) error { + switch name { + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execRunningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execRunningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process") +} + +func (s *execRunningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process") +} + +func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execRunningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execStoppedState struct { + p *execProcess +} + +func (s *execStoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to
%q", stateName(s), name) + } + return nil +} + +func (s *execStoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *execStoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process") +} + +func (s *execStoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execStoppedState) SetExited(status int) { + // no op +} diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go new file mode 100644 index 000000000..dab3123d6 --- /dev/null +++ b/pkg/shim/v1/proc/init.go @@ -0,0 +1,460 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +// InitPidFile name of the file that contains the init pid. +const InitPidFile = "init.pid" + +// Init represents an initial process for a container. +type Init struct { + wg sync.WaitGroup + initState initState + + // mu is used to ensure that `Start()` and `Exited()` calls return in + // the right order when invoked in separate go routines. This is the + // case within the shim implementation as it makes use of the reaper + // interface. + mu sync.Mutex + + waitBlock chan struct{} + + WorkDir string + + id string + Bundle string + console console.Console + Platform stdio.Platform + io runc.IO + runtime *runsc.Runsc + status int + exited time.Time + pid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + Rootfs string + IoUID int + IoGID int + Sandbox bool + UserLog string + Monitor ProcessMonitor +} + +// NewRunsc returns a new runsc instance for a process. +func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc { + if root == "" { + root = RunscRoot + } + return &runsc.Runsc{ + Command: runtime, + PdeathSignal: syscall.SIGKILL, + Log: filepath.Join(path, "log.json"), + LogFormat: runc.JSON, + Root: filepath.Join(root, namespace), + Config: config, + } +} + +// New returns a new init process. 
+func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init { + p := &Init{ + id: id, + runtime: runtime, + stdio: stdio, + status: 0, + waitBlock: make(chan struct{}), + } + p.initState = &createdState{p: p} + return p +} + +// Create the process with the provided config. +func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) { + var socket *runc.Socket + if r.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create OCI runtime console socket: %w", err) + } + defer socket.Close() + } else if hasNoIO(r) { + if p.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil { + return fmt.Errorf("failed to create OCI runtime io pipes: %w", err) + } + } + pidFile := filepath.Join(p.Bundle, InitPidFile) + opts := &runsc.CreateOpts{ + PidFile: pidFile, + } + if socket != nil { + opts.ConsoleSocket = socket + } + if p.Sandbox { + opts.IO = p.io + // UserLog is only useful for sandbox. + opts.UserLog = p.UserLog + } + if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil { + return p.runtimeError(err, "OCI runtime create failed") + } + if r.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err) + } + p.stdin = sc + p.closers = append(p.closers, sc) + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg) + if err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + p.console = console + } else if !hasNoIO(r) { + if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(pidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err) + } + p.pid = pid + return nil +} + +// Wait waits for the process to exit. +func (p *Init) Wait() { + <-p.waitBlock +} + +// ID returns the ID of the process. +func (p *Init) ID() string { + return p.id +} + +// Pid returns the PID of the process. +func (p *Init) Pid() int { + return p.pid +} + +// ExitStatus returns the exit status of the process. +func (p *Init) ExitStatus() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.status +} + +// ExitedAt returns the time when the process exited. +func (p *Init) ExitedAt() time.Time { + p.mu.Lock() + defer p.mu.Unlock() + return p.exited +} + +// Status returns the status of the process. +func (p *Init) Status(ctx context.Context) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + c, err := p.runtime.State(ctx, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return "stopped", nil + } + return "", p.runtimeError(err, "OCI runtime state failed") + } + return p.convertStatus(c.Status), nil +} + +// Start starts the init process. 
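+// Start dispatches through the current initState, so it only succeeds from the created state.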
+func (p *Init) Start(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Start(ctx) +} + +func (p *Init) start(ctx context.Context) error { + var cio runc.IO + if !p.Sandbox { + cio = p.io + } + if err := p.runtime.Start(ctx, p.id, cio); err != nil { + return p.runtimeError(err, "OCI runtime start failed") + } + go func() { + status, err := p.runtime.Wait(context.Background(), p.id) + if err != nil { + log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id) + // TODO(random-liu): Handle runsc kill error. + if err := p.killAll(ctx); err != nil { + log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id) + } + status = internalErrorCode + } + ExitCh <- Exit{ + Timestamp: time.Now(), + ID: p.id, + Status: status, + } + }() + return nil +} + +// SetExited sets the exit status of the init process. +func (p *Init) SetExited(status int) { + p.mu.Lock() + defer p.mu.Unlock() + + p.initState.SetExited(status) +} + +func (p *Init) setExited(status int) { + p.exited = time.Now() + p.status = status + p.Platform.ShutdownConsole(context.Background(), p.console) + close(p.waitBlock) +} + +// Delete deletes the init process. +func (p *Init) Delete(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Delete(ctx) +} + +func (p *Init) delete(ctx context.Context) error { + p.killAll(ctx) + p.wg.Wait() + err := p.runtime.Delete(ctx, p.id, nil) + // ignore errors if a runtime has already deleted the process + // but we still hold metadata and pipes + // + // this is common during a checkpoint, runc will delete the container state + // after a checkpoint and the container will no longer exist within runc + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + err = nil + } else { + err = p.runtimeError(err, "failed to delete task") + } + } + if p.io != nil { + for _, c := range p.closers { + c.Close() + } + p.io.Close() + } + if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount") + if err == nil { + err = fmt.Errorf("failed rootfs umount: %w", err2) + } + } + return err +} + +// Resize resizes the init process's console. +func (p *Init) Resize(ws console.WinSize) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +func (p *Init) resize(ws console.WinSize) error { + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +// Kill kills the init process. +func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Kill(ctx, signal, all) +} + +func (p *Init) kill(context context.Context, signal uint32, all bool) error { + var ( + killErr error + backoff = 100 * time.Millisecond + ) + timeout := 1 * time.Second + for start := time.Now(); time.Now().Sub(start) < timeout; { + c, err := p.runtime.State(context, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + return p.runtimeError(err, "OCI runtime state failed") + } + // For runsc, signal only works when the container is in the running state.
+ // If the container is not in running state, directly return + // "no such process" + if p.convertStatus(c.Status) == "stopped" { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{ + All: all, + }) + if killErr == nil { + return nil + } + time.Sleep(backoff) + backoff *= 2 + } + return p.runtimeError(killErr, "kill timeout") +} + +// KillAll kills all processes belonging to the init process. +func (p *Init) KillAll(context context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + return p.killAll(context) +} + +func (p *Init) killAll(context context.Context) error { + p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{ + All: true, + }) + // Ignore error handling for `runsc kill --all` for now. + // * If it doesn't return error, it is good; + // * If it returns error, consider the container has already stopped. + // TODO: Fix `runsc kill --all` error handling. + return nil +} + +// Stdin returns the stdin of the process. +func (p *Init) Stdin() io.Closer { + return p.stdin +} + +// Runtime returns the OCI runtime configured for the init process. +func (p *Init) Runtime() *runsc.Runsc { + return p.runtime +} + +// Exec returns a new child process. +func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Exec(ctx, path, r) +} + +// exec returns a new exec'd process. +func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + // process exec request + var spec specs.Process + if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { + return nil, err + } + spec.Terminal = r.Terminal + + e := &execProcess{ + id: r.ID, + path: path, + parent: p, + spec: spec, + stdio: stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }, + waitBlock: make(chan struct{}), + } + e.execState = &execCreatedState{p: e} + return e, nil +} + +// Stdio returns the stdio of the process. +func (p *Init) Stdio() stdio.Stdio { + return p.stdio +} + +func (p *Init) runtimeError(rErr error, msg string) error { + if rErr == nil { + return nil + } + + rMsg, err := getLastRuntimeError(p.runtime) + switch { + case err != nil: + return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err) + case rMsg == "": + return fmt.Errorf("%s: %w", msg, rErr) + default: + return fmt.Errorf("%s: %s", msg, rMsg) + } +} + +func (p *Init) convertStatus(status string) string { + if status == "created" && !p.Sandbox && p.status == internalErrorCode { + // Treat start failure state for non-root container as stopped. + return "stopped" + } + return status +} + +func withConditionalIO(c stdio.Stdio) runc.IOOpt { + return func(o *runc.IOOption) { + o.OpenStdin = c.Stdin != "" + o.OpenStdout = c.Stdout != "" + o.OpenStderr = c.Stderr != "" + } +} diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go new file mode 100644 index 000000000..9233ecc85 --- /dev/null +++ b/pkg/shim/v1/proc/init_state.go @@ -0,0 +1,182 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type initState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Exec(context.Context, string, *ExecConfig) (process.Process, error) + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type createdState struct { + p *Init +} + +func (s *createdState) transition(name string) error { + switch name { + case "running": + s.p.initState = &runningState{p: s.p} + case "stopped": + s.p.initState = &stoppedState{p: s.p} + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *createdState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *createdState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + // Containerd doesn't allow deleting a container in the created state. + // However, for gvisor, a non-root container in the created state can + // only go to the running state. If the container can't be started, + // it can only stay in the created state, and never be deleted. + // To work around that, we treat a non-root container in the start + // failure state as stopped.
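+ // The sandbox (root) container, by contrast, stays in the created state here and the start error is returned unchanged.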
+ if !s.p.Sandbox { + s.p.io.Close() + s.p.setExited(internalErrorCode) + if err := s.transition("stopped"); err != nil { + panic(err) + } + } + return err + } + return s.transition("running") +} + +func (s *createdState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *createdState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type runningState struct { + p *Init +} + +func (s *runningState) transition(name string) error { + switch name { + case "stopped": + s.p.initState = &stoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *runningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *runningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process") +} + +func (s *runningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process") +} + +func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *runningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type stoppedState struct { + p *Init +} + +func (s *stoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *stoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *stoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process") +} + +func (s *stoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s not found", s.p.id) +} + +func (s *stoppedState) SetExited(status int) { + // no op +} + +func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a stopped state") +} diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go new file mode 100644 index 000000000..34d825fb7 --- /dev/null +++ b/pkg/shim/v1/proc/io.go @@ -0,0 +1,162 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "syscall" + + "github.com/containerd/containerd/log" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" +) + +// TODO(random-liu): This file can be a util. + +var bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, +} + +func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error { + var sameFile *countingWriteCloser + for _, i := range []struct { + name string + dest func(wc io.WriteCloser, rc io.Closer) + }{ + { + name: stdout, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil { + log.G(ctx).Warn("error copying stdout") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, { + name: stderr, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil { + log.G(ctx).Warn("error copying stderr") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, + } { + ok, err := isFifo(i.name) + if err != nil { + return err + } + var ( + fw io.WriteCloser + fr io.Closer + ) + if ok { + if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + } else { + if sameFile != nil { + sameFile.count++ + i.dest(sameFile, nil) + continue + } + if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if stdout == stderr { + sameFile = &countingWriteCloser{ + WriteCloser: fw, + count: 1, + } + } + } + i.dest(fw, fr) + } + if stdin == "" { + return nil + } + f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err) + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + + io.CopyBuffer(rio.Stdin(), f, *p) + rio.Stdin().Close() + f.Close() + }() + return nil +} + +// countingWriteCloser masks io.Closer() until close has been invoked a certain number of times. +type countingWriteCloser struct { + io.WriteCloser + count int64 +} + +func (c *countingWriteCloser) Close() error { + if atomic.AddInt64(&c.count, -1) > 0 { + return nil + } + return c.WriteCloser.Close() +} + +// isFifo checks if a file is a fifo. +// +// If the file does not exist then it returns false. 
+func isFifo(path string) (bool, error) {
+	stat, err := os.Stat(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return false, nil
+		}
+		return false, err
+	}
+	if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe {
+		return true, nil
+	}
+	return false, nil
+}
diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go
new file mode 100644
index 000000000..d462c3eef
--- /dev/null
+++ b/pkg/shim/v1/proc/process.go
@@ -0,0 +1,37 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+	"fmt"
+)
+
+// RunscRoot is the path to the root runsc state directory.
+const RunscRoot = "/run/containerd/runsc"
+
+func stateName(v interface{}) string {
+	switch v.(type) {
+	case *runningState, *execRunningState:
+		return "running"
+	case *createdState, *execCreatedState:
+		return "created"
+	case *deletedState:
+		return "deleted"
+	case *stoppedState:
+		return "stopped"
+	}
+	panic(fmt.Errorf("invalid state %v", v))
+}
diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go
new file mode 100644
index 000000000..2b0df4663
--- /dev/null
+++ b/pkg/shim/v1/proc/types.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+	"time"
+
+	runc "github.com/containerd/go-runc"
+	"github.com/gogo/protobuf/types"
+)
+
+// Mount holds filesystem mount configuration.
+type Mount struct {
+	Type    string
+	Source  string
+	Target  string
+	Options []string
+}
+
+// CreateConfig holds task creation configuration.
+type CreateConfig struct {
+	ID       string
+	Bundle   string
+	Runtime  string
+	Rootfs   []Mount
+	Terminal bool
+	Stdin    string
+	Stdout   string
+	Stderr   string
+	Options  *types.Any
+}
+
+// ExecConfig holds exec creation configuration.
+type ExecConfig struct {
+	ID       string
+	Terminal bool
+	Stdin    string
+	Stdout   string
+	Stderr   string
+	Spec     *types.Any
+}
+
+// Exit is the type of exit events.
+type Exit struct {
+	Timestamp time.Time
+	ID        string
+	Status    int
+}
+
+// ProcessMonitor monitors process exit changes.
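+//
+// A minimal consumer sketch (monitor is any ProcessMonitor; handleExit is a
+// hypothetical callback):
+//
+//	exits := monitor.Subscribe()
+//	defer monitor.Unsubscribe(exits)
+//	for e := range exits {
+//		handleExit(e.Pid, e.Status)
+//	}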
+type ProcessMonitor interface {
+	// Subscribe subscribes to process exit changes.
+	Subscribe() chan runc.Exit
+	// Unsubscribe unsubscribes from process exit changes.
+	Unsubscribe(c chan runc.Exit)
+}
diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go
new file mode 100644
index 000000000..716de2f59
--- /dev/null
+++ b/pkg/shim/v1/proc/utils.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+	"encoding/json"
+	"io"
+	"os"
+	"strings"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+const (
+	internalErrorCode = 128
+	bufferSize        = 32
+)
+
+// ExitCh is the exit events channel for containers and exec processes
+// inside the sandbox.
+var ExitCh = make(chan Exit, bufferSize)
+
+// TODO(mlaventure): move to runc package?
+func getLastRuntimeError(r *runsc.Runsc) (string, error) {
+	if r.Log == "" {
+		return "", nil
+	}
+
+	f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400)
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	var (
+		errMsg string
+		log    struct {
+			Level string
+			Msg   string
+			Time  time.Time
+		}
+	)
+
+	dec := json.NewDecoder(f)
+	for err = nil; err == nil; {
+		if err = dec.Decode(&log); err != nil && err != io.EOF {
+			return "", err
+		}
+		if log.Level == "error" {
+			errMsg = strings.TrimSpace(log.Msg)
+		}
+	}
+
+	return errMsg, nil
+}
+
+func copyFile(to, from string) error {
+	ff, err := os.Open(from)
+	if err != nil {
+		return err
+	}
+	defer ff.Close()
+	tt, err := os.Create(to)
+	if err != nil {
+		return err
+	}
+	defer tt.Close()
+
+	p := bufPool.Get().(*[]byte)
+	defer bufPool.Put(p)
+	_, err = io.CopyBuffer(tt, ff, *p)
+	return err
+}
+
+func hasNoIO(r *CreateConfig) bool {
+	return r.Stdin == "" && r.Stdout == "" && r.Stderr == ""
+}
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
new file mode 100644
index 000000000..05c595bc9
--- /dev/null
+++ b/pkg/shim/v1/shim/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+    name = "shim",
+    srcs = [
+        "api.go",
+        "platform.go",
+        "service.go",
+    ],
+    visibility = [
+        "//pkg/shim:__subpackages__",
+        "//shim:__subpackages__",
+    ],
+    deps = [
+        "//pkg/shim/runsc",
+        "//pkg/shim/v1/proc",
+        "//pkg/shim/v1/utils",
+        "@com_github_containerd_console//:go_default_library",
+        "@com_github_containerd_containerd//api/events:go_default_library",
+        "@com_github_containerd_containerd//api/types/task:go_default_library",
+        "@com_github_containerd_containerd//errdefs:go_default_library",
+        "@com_github_containerd_containerd//events:go_default_library",
+        "@com_github_containerd_containerd//log:go_default_library",
+        "@com_github_containerd_containerd//mount:go_default_library",
+        "@com_github_containerd_containerd//namespaces:go_default_library",
+        "@com_github_containerd_containerd//pkg/process:go_default_library",
+        "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+        "@com_github_containerd_containerd//runtime:go_default_library",
+        "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+        "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
+        "@com_github_containerd_containerd//sys/reaper:go_default_library",
+        "@com_github_containerd_fifo//:go_default_library",
+        "@com_github_containerd_typeurl//:go_default_library",
+        "@com_github_gogo_protobuf//types:go_default_library",
+        "@org_golang_google_grpc//codes:go_default_library",
+        "@org_golang_google_grpc//status:go_default_library",
+    ],
+)
diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go
new file mode 100644
index 000000000..5dd8ff172
--- /dev/null
+++ b/pkg/shim/v1/shim/api.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+	"github.com/containerd/containerd/api/events"
+)
+
+type TaskCreate = events.TaskCreate
+type TaskStart = events.TaskStart
+type TaskOOM = events.TaskOOM
+type TaskExit = events.TaskExit
+type TaskDelete = events.TaskDelete
+type TaskExecAdded = events.TaskExecAdded
+type TaskExecStarted = events.TaskExecStarted
diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go
new file mode 100644
index 000000000..f590f80ef
--- /dev/null
+++ b/pkg/shim/v1/shim/platform.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package shim + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *Service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go new file mode 100644 index 000000000..84a810cb2 --- /dev/null +++ b/pkg/shim/v1/shim/service.go @@ -0,0 +1,573 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package shim
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sync"
+
+	"github.com/containerd/console"
+	"github.com/containerd/containerd/api/types/task"
+	"github.com/containerd/containerd/errdefs"
+	"github.com/containerd/containerd/events"
+	"github.com/containerd/containerd/log"
+	"github.com/containerd/containerd/mount"
+	"github.com/containerd/containerd/namespaces"
+	"github.com/containerd/containerd/pkg/process"
+	"github.com/containerd/containerd/pkg/stdio"
+	"github.com/containerd/containerd/runtime"
+	"github.com/containerd/containerd/runtime/linux/runctypes"
+	shim "github.com/containerd/containerd/runtime/v1/shim/v1"
+	"github.com/containerd/containerd/sys/reaper"
+	"github.com/containerd/typeurl"
+	"github.com/gogo/protobuf/types"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+
+	"gvisor.dev/gvisor/pkg/shim/runsc"
+	"gvisor.dev/gvisor/pkg/shim/v1/proc"
+	"gvisor.dev/gvisor/pkg/shim/v1/utils"
+)
+
+var (
+	empty   = &types.Empty{}
+	bufPool = sync.Pool{
+		New: func() interface{} {
+			buffer := make([]byte, 32<<10)
+			return &buffer
+		},
+	}
+)
+
+// Config contains shim specific configuration.
+type Config struct {
+	Path        string
+	Namespace   string
+	WorkDir     string
+	RuntimeRoot string
+	RunscConfig map[string]string
+}
+
+// NewService returns a new shim service that can be used via GRPC.
+func NewService(config Config, publisher events.Publisher) (*Service, error) {
+	if config.Namespace == "" {
+		return nil, fmt.Errorf("shim namespace cannot be empty")
+	}
+	ctx := namespaces.WithNamespace(context.Background(), config.Namespace)
+	s := &Service{
+		config:    config,
+		context:   ctx,
+		processes: make(map[string]process.Process),
+		events:    make(chan interface{}, 128),
+		ec:        proc.ExitCh,
+	}
+	go s.processExits()
+	if err := s.initPlatform(); err != nil {
+		return nil, fmt.Errorf("failed to initialize platform behavior: %w", err)
+	}
+	go s.forward(publisher)
+	return s, nil
+}
+
+// Service is the shim implementation of a remote shim over GRPC.
+type Service struct {
+	mu sync.Mutex
+
+	config    Config
+	context   context.Context
+	processes map[string]process.Process
+	events    chan interface{}
+	platform  stdio.Platform
+	ec        chan proc.Exit
+
+	// Filled by Create()
+	id     string
+	bundle string
+}
+
+// Create creates a new initial process and container with the underlying OCI runtime.
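+//
+// The rootfs mounts from the request are applied under <bundle>/rootfs, an
+// Init process is constructed from the shim Config, and the container is
+// then created via Init.Create; on failure the deferred UnmountAll unwinds
+// any rootfs mounts that were made.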
+func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	var mounts []proc.Mount
+	for _, m := range r.Rootfs {
+		mounts = append(mounts, proc.Mount{
+			Type:    m.Type,
+			Source:  m.Source,
+			Target:  m.Target,
+			Options: m.Options,
+		})
+	}
+
+	rootfs := filepath.Join(r.Bundle, "rootfs")
+	if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+		return nil, err
+	}
+
+	config := &proc.CreateConfig{
+		ID:       r.ID,
+		Bundle:   r.Bundle,
+		Runtime:  r.Runtime,
+		Rootfs:   mounts,
+		Terminal: r.Terminal,
+		Stdin:    r.Stdin,
+		Stdout:   r.Stdout,
+		Stderr:   r.Stderr,
+		Options:  r.Options,
+	}
+	defer func() {
+		if err != nil {
+			if err2 := mount.UnmountAll(rootfs, 0); err2 != nil {
+				log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount")
+			}
+		}
+	}()
+	for _, rm := range mounts {
+		m := &mount.Mount{
+			Type:    rm.Type,
+			Source:  rm.Source,
+			Options: rm.Options,
+		}
+		if err := m.Mount(rootfs); err != nil {
+			return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+		}
+	}
+	process, err := newInit(
+		ctx,
+		s.config.Path,
+		s.config.WorkDir,
+		s.config.RuntimeRoot,
+		s.config.Namespace,
+		s.config.RunscConfig,
+		s.platform,
+		config,
+	)
+	if err != nil {
+		return nil, errdefs.ToGRPC(err)
+	}
+	if err := process.Create(ctx, config); err != nil {
+		return nil, errdefs.ToGRPC(err)
+	}
+	// Save the main task id and bundle to the shim for additional
+	// requests.
+	s.id = r.ID
+	s.bundle = r.Bundle
+	pid := process.Pid()
+	s.processes[r.ID] = process
+	return &shim.CreateTaskResponse{
+		Pid: uint32(pid),
+	}, nil
+}
+
+// Start starts a process.
+func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) {
+	p, err := s.getExecProcess(r.ID)
+	if err != nil {
+		return nil, err
+	}
+	if err := p.Start(ctx); err != nil {
+		return nil, err
+	}
+	return &shim.StartResponse{
+		ID:  p.ID(),
+		Pid: uint32(p.Pid()),
+	}, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) {
+	p, err := s.getInitProcess()
+	if err != nil {
+		return nil, err
+	}
+	if err := p.Delete(ctx); err != nil {
+		return nil, err
+	}
+	s.mu.Lock()
+	delete(s.processes, s.id)
+	s.mu.Unlock()
+	s.platform.Close()
+	return &shim.DeleteResponse{
+		ExitStatus: uint32(p.ExitStatus()),
+		ExitedAt:   p.ExitedAt(),
+		Pid:        uint32(p.Pid()),
+	}, nil
+}
+
+// DeleteProcess deletes an exec'd process.
+func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) {
+	if r.ID == s.id {
+		return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess")
+	}
+	p, err := s.getExecProcess(r.ID)
+	if err != nil {
+		return nil, err
+	}
+	if err := p.Delete(ctx); err != nil {
+		return nil, err
+	}
+	s.mu.Lock()
+	delete(s.processes, r.ID)
+	s.mu.Unlock()
+	return &shim.DeleteResponse{
+		ExitStatus: uint32(p.ExitStatus()),
+		ExitedAt:   p.ExitedAt(),
+		Pid:        uint32(p.Pid()),
+	}, nil
+}
+
+// Exec spawns an additional process inside the container.
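+//
+// The new process is only registered under r.ID here; it does not run until
+// the client issues a separate Start request for that ID.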
+func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) {
+	s.mu.Lock()
+
+	if p := s.processes[r.ID]; p != nil {
+		s.mu.Unlock()
+		return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID)
+	}
+
+	p := s.processes[s.id]
+	s.mu.Unlock()
+	if p == nil {
+		return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+	}
+
+	process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{
+		ID:       r.ID,
+		Terminal: r.Terminal,
+		Stdin:    r.Stdin,
+		Stdout:   r.Stdout,
+		Stderr:   r.Stderr,
+		Spec:     r.Spec,
+	})
+	if err != nil {
+		return nil, errdefs.ToGRPC(err)
+	}
+	s.mu.Lock()
+	s.processes[r.ID] = process
+	s.mu.Unlock()
+	return empty, nil
+}
+
+// ResizePty resizes the terminal of a process.
+func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) {
+	if r.ID == "" {
+		return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided")
+	}
+	ws := console.WinSize{
+		Width:  uint16(r.Width),
+		Height: uint16(r.Height),
+	}
+	p, err := s.getExecProcess(r.ID)
+	if err != nil {
+		return nil, err
+	}
+	if err := p.Resize(ws); err != nil {
+		return nil, errdefs.ToGRPC(err)
+	}
+	return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) {
+	p, err := s.getExecProcess(r.ID)
+	if err != nil {
+		return nil, err
+	}
+	st, err := p.Status(ctx)
+	if err != nil {
+		return nil, err
+	}
+	status := task.StatusUnknown
+	switch st {
+	case "created":
+		status = task.StatusCreated
+	case "running":
+		status = task.StatusRunning
+	case "stopped":
+		status = task.StatusStopped
+	}
+	sio := p.Stdio()
+	return &shim.StateResponse{
+		ID:         p.ID(),
+		Bundle:     s.bundle,
+		Pid:        uint32(p.Pid()),
+		Status:     status,
+		Stdin:      sio.Stdin,
+		Stdout:     sio.Stdout,
+		Stderr:     sio.Stderr,
+		Terminal:   sio.Terminal,
+		ExitStatus: uint32(p.ExitStatus()),
+		ExitedAt:   p.ExitedAt(),
+	}, nil
+}
+
+// Pause pauses the container.
+func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+	return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume resumes the container.
+func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+	return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill kills a process with the provided signal.
+func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) {
+	if r.ID == "" {
+		p, err := s.getInitProcess()
+		if err != nil {
+			return nil, err
+		}
+		if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+			return nil, errdefs.ToGRPC(err)
+		}
+		return empty, nil
+	}
+
+	p, err := s.getExecProcess(r.ID)
+	if err != nil {
+		return nil, err
+	}
+	if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+		return nil, errdefs.ToGRPC(err)
+	}
+	return empty, nil
+}
+
+// ListPids returns all pids inside the container.
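+//
+// The pids are read back from runsc via the init process's runtime, and any
+// pid that corresponds to a known exec process is annotated with its exec
+// ID.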
+func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &shim.ListPidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// ShimInfo returns shim information such as the shim's pid. +func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) { + return &shim.ShimInfoResponse{ + ShimPid: uint32(os.Getpid()), + }, nil +} + +// Update updates a running container. +func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + p.Wait() + + return &shim.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *Service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *Service) allProcesses() []process.Process { + s.mu.Lock() + defer s.mu.Unlock() + + res := make([]process.Process, 0, len(s.processes)) + for _, p := range s.processes { + res = append(res, p) + } + return res +} + +func (s *Service) checkProcesses(e proc.Exit) { + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). 
+ Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *Service) forward(publisher events.Publisher) { + for e := range s.events { + if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil { + log.G(s.context).WithError(err).Error("post event") + } + } +} + +// getInitProcess returns the init process. +func (s *Service) getInitProcess() (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[s.id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + return p, nil +} + +// getExecProcess returns the given exec process. +func (s *Service) getExecProcess(id string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id) + } + return p, nil +} + +func getTopic(ctx context.Context, e interface{}) string { + switch e.(type) { + case *TaskCreate: + return runtime.TaskCreateEventTopic + case *TaskStart: + return runtime.TaskStartEventTopic + case *TaskOOM: + return runtime.TaskOOMEventTopic + case *TaskExit: + return runtime.TaskExitEventTopic + case *TaskDelete: + return runtime.TaskDeleteEventTopic + case *TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) { + var options runctypes.CreateOptions + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + options = *v.(*runctypes.CreateOptions) + } + + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + + runsc.FormatLogPath(r.ID, config) + rootfs := filepath.Join(path, "rootfs") + runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = utils.IsSandbox(spec) + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD new file mode 100644 index 000000000..54a0aabb7 --- /dev/null +++ b/pkg/shim/v1/utils/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "utils", + srcs = [ + "annotations.go", + 
"utils.go", + "volumes.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) + +go_test( + name = "utils_test", + size = "small", + srcs = ["volumes_test.go"], + library = ":utils", + deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], +) diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go new file mode 100644 index 000000000..1e9d3f365 --- /dev/null +++ b/pkg/shim/v1/utils/annotations.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +// Annotations from the CRI annotations package. +// +// These are vendor due to import conflicts. +const ( + sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory" + containerTypeAnnotation = "io.kubernetes.cri.container-type" + containerTypeSandbox = "sandbox" + containerTypeContainer = "container" +) diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go new file mode 100644 index 000000000..07e346654 --- /dev/null +++ b/pkg/shim/v1/utils/utils.go @@ -0,0 +1,56 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "io/ioutil" + "os" + "path/filepath" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +// ReadSpec reads OCI spec from the bundle directory. +func ReadSpec(bundle string) (*specs.Spec, error) { + f, err := os.Open(filepath.Join(bundle, "config.json")) + if err != nil { + return nil, err + } + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + return nil, err + } + return &spec, nil +} + +// IsSandbox checks whether a container is a sandbox container. +func IsSandbox(spec *specs.Spec) bool { + t, ok := spec.Annotations[containerTypeAnnotation] + return !ok || t == containerTypeSandbox +} + +// UserLogPath gets user log path from OCI annotation. 
+func UserLogPath(spec *specs.Spec) string {
+	sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation]
+	if sandboxLogDir == "" {
+		return ""
+	}
+	return filepath.Join(sandboxLogDir, "gvisor.log")
+}
diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go
new file mode 100644
index 000000000..52a428179
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"path/filepath"
+	"strings"
+
+	specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const volumeKeyPrefix = "dev.gvisor.spec.mount."
+
+var kubeletPodsDir = "/var/lib/kubelet/pods"
+
+// volumeName gets volume name from volume annotation key, example:
+// dev.gvisor.spec.mount.NAME.share
+func volumeName(k string) string {
+	return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0]
+}
+
+// volumeFieldName gets volume field name from volume annotation key, example:
+// `type` is the field of dev.gvisor.spec.mount.NAME.type
+func volumeFieldName(k string) string {
+	parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".")
+	return parts[len(parts)-1]
+}
+
+// podUID gets pod UID from the pod log path.
+func podUID(s *specs.Spec) (string, error) {
+	sandboxLogDir := s.Annotations[sandboxLogDirAnnotation]
+	if sandboxLogDir == "" {
+		return "", fmt.Errorf("no sandbox log path annotation")
+	}
+	fields := strings.Split(filepath.Base(sandboxLogDir), "_")
+	switch len(fields) {
+	case 1: // This is the old CRI logging path.
+		return fields[0], nil
+	case 3: // This is the new CRI logging path.
+		return fields[2], nil
+	}
+	return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir)
+}
+
+// isVolumeKey checks whether an annotation key is for a volume.
+func isVolumeKey(k string) bool {
+	return strings.HasPrefix(k, volumeKeyPrefix)
+}
+
+// volumeSourceKey constructs the annotation key for volume source.
+func volumeSourceKey(volume string) string {
+	return volumeKeyPrefix + volume + ".source"
+}
+
+// volumePath searches the volume path in the kubelet pod directory.
+func volumePath(volume, uid string) (string, error) {
+	// TODO: Support subpath when gvisor supports pod volume bind mount.
+	volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume)
+	dirs, err := filepath.Glob(volumeSearchPath)
+	if err != nil {
+		return "", err
+	}
+	if len(dirs) != 1 {
+		return "", fmt.Errorf("unexpected matched volume list %v", dirs)
+	}
+	return dirs[0], nil
+}
+
+// isVolumePath checks whether a string is the volume path.
+func isVolumePath(volume, path string) (bool, error) {
+	// TODO: Support subpath when gvisor supports pod volume bind mount.
+	volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume)
+	return filepath.Match(volumeSearchPath, path)
+}
+
+// UpdateVolumeAnnotations adds necessary OCI annotations for gvisor
+// volume optimization.
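+//
+// For a sandbox, the kubelet volume path is resolved and recorded as, e.g.,
+//
+//	"dev.gvisor.spec.mount.<name>.source": "/var/lib/kubelet/pods/<uid>/volumes/.../<name>"
+//
+// while for a container the Type of any mount whose Source matches the
+// volume is replaced by the annotated "dev.gvisor.spec.mount.<name>.type"
+// value.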
+func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { + var ( + uid string + err error + ) + if IsSandbox(s) { + uid, err = podUID(s) + if err != nil { + // Skip if we can't get pod UID, because this doesn't work + // for containerd 1.1. + return nil + } + } + var updated bool + for k, v := range s.Annotations { + if !isVolumeKey(k) { + continue + } + if volumeFieldName(k) != "type" { + continue + } + volume := volumeName(k) + if uid != "" { + // This is a sandbox. + path, err := volumePath(volume, uid) + if err != nil { + return fmt.Errorf("get volume path for %q: %w", volume, err) + } + s.Annotations[volumeSourceKey(volume)] = path + updated = true + } else { + // This is a container. + for i := range s.Mounts { + // An error is returned for sandbox if source + // annotation is not successfully applied, so + // it is guaranteed that the source annotation + // for sandbox has already been successfully + // applied at this point. + // + // The volume name is unique inside a pod, so + // matching without podUID is fine here. + // + // TODO: Pass podUID down to shim for containers to do + // more accurate matching. + if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { + // gVisor requires the container mount type to match + // sandbox mount type. + s.Mounts[i].Type = v + updated = true + } + } + } + } + if !updated { + return nil + } + // Update bundle. + b, err := json.Marshal(s) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) +} diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go new file mode 100644 index 000000000..3e02c6151 --- /dev/null +++ b/pkg/shim/v1/utils/volumes_test.go @@ -0,0 +1,308 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +func TestUpdateVolumeAnnotations(t *testing.T) { + dir, err := ioutil.TempDir("", "test-update-volume-annotations") + if err != nil { + t.Fatalf("create tempdir: %v", err) + } + defer os.RemoveAll(dir) + kubeletPodsDir = dir + + const ( + testPodUID = "testuid" + testVolumeName = "testvolume" + testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID + testLegacyLogDirPath = "/var/log/pods/" + testPodUID + ) + testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName) + + if err := os.MkdirAll(testVolumePath, 0755); err != nil { + t.Fatalf("Create test volume: %v", err) + } + + for _, test := range []struct { + desc string + spec *specs.Spec + expected *specs.Spec + expectErr bool + expectUpdate bool + }{ + { + desc: "volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." 
+ testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "volume annotations for sandbox with legacy log path", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "tmpfs: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "bind: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." 
+ testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "should not return error without pod log directory", + spec: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + }, + { + desc: "should return error if volume path does not exist", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount.notexist.share": "pod", + "dev.gvisor.spec.mount.notexist.type": "tmpfs", + "dev.gvisor.spec.mount.notexist.options": "ro", + }, + }, + expectErr: true, + }, + { + desc: "no volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + }, + { + desc: "no volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + bundle, err := ioutil.TempDir(dir, "test-bundle") + if err != nil { + t.Fatalf("Create test bundle: %v", err) + } + err = UpdateVolumeAnnotations(bundle, test.spec) + if test.expectErr { + if err == nil { + t.Fatal("Expected error, but got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(test.expected, test.spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) + } + if test.expectUpdate { + b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json")) + if err != nil { + t.Fatalf("Read spec from bundle: %v", err) + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + t.Fatalf("Unmarshal spec: %v", err) + } + if !reflect.DeepEqual(test.expected, &spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, &spec) + } + } + }) + } +} diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD new file mode 100644 index 000000000..7e0a114a0 --- /dev/null +++ b/pkg/shim/v2/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "v2", + srcs = [ + "api.go", + "epoll.go", + "service.go", + "service_linux.go", + 
], + visibility = ["//shim:__subpackages__"], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "//pkg/shim/v2/options", + "//pkg/shim/v2/runtimeoptions", + "//runsc/specutils", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + "@com_github_containerd_containerd//runtime/v2/task:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go new file mode 100644 index 000000000..dbe5c59f6 --- /dev/null +++ b/pkg/shim/v2/api.go @@ -0,0 +1,22 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskOOM = events.TaskOOM diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go new file mode 100644 index 000000000..41232cca8 --- /dev/null +++ b/pkg/shim/v2/epoll.go @@ -0,0 +1,129 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// +build linux + +package v2 + +import ( + "context" + "fmt" + "sync" + + "github.com/containerd/cgroups" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/runtime" + "golang.org/x/sys/unix" +) + +func newOOMEpoller(publisher events.Publisher) (*epoller, error) { + fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + return &epoller{ + fd: fd, + publisher: publisher, + set: make(map[uintptr]*item), + }, nil +} + +type epoller struct { + mu sync.Mutex + + fd int + publisher events.Publisher + set map[uintptr]*item +} + +type item struct { + id string + cg cgroups.Cgroup +} + +func (e *epoller) Close() error { + return unix.Close(e.fd) +} + +func (e *epoller) run(ctx context.Context) { + var events [128]unix.EpollEvent + for { + select { + case <-ctx.Done(): + e.Close() + return + default: + n, err := unix.EpollWait(e.fd, events[:], -1) + if err != nil { + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + // Should not happen. + panic(fmt.Errorf("cgroups: epoll wait: %w", err)) + } + for i := 0; i < n; i++ { + e.process(ctx, uintptr(events[i].Fd)) + } + } + } +} + +func (e *epoller) add(id string, cg cgroups.Cgroup) error { + e.mu.Lock() + defer e.mu.Unlock() + fd, err := cg.OOMEventFD() + if err != nil { + return err + } + e.set[fd] = &item{ + id: id, + cg: cg, + } + event := unix.EpollEvent{ + Fd: int32(fd), + Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR, + } + return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event) +} + +func (e *epoller) process(ctx context.Context, fd uintptr) { + flush(fd) + e.mu.Lock() + i, ok := e.set[fd] + if !ok { + e.mu.Unlock() + return + } + e.mu.Unlock() + if i.cg.State() == cgroups.Deleted { + e.mu.Lock() + delete(e.set, fd) + e.mu.Unlock() + unix.Close(int(fd)) + return + } + if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{ + ContainerID: i.id, + }); err != nil { + // Should not happen. + panic(fmt.Errorf("publish OOM event: %w", err)) + } +} + +func flush(fd uintptr) error { + var buf [8]byte + _, err := unix.Read(int(fd), buf[:]) + return err +} diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD new file mode 100644 index 000000000..ca212e874 --- /dev/null +++ b/pkg/shim/v2/options/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "options", + srcs = [ + "options.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go new file mode 100644 index 000000000..de09f2f79 --- /dev/null +++ b/pkg/shim/v2/options/options.go @@ -0,0 +1,33 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +const OptionType = "io.containerd.runsc.v1.options" + +// Options is runtime options for io.containerd.runsc.v1. +type Options struct { + // ShimCgroup is the cgroup the shim should be in. 
+	ShimCgroup string `toml:"shim_cgroup"`
+	// IoUid is the I/O's pipes uid.
+	IoUid uint32 `toml:"io_uid"`
+	// IoGid is the I/O's pipes gid.
+	IoGid uint32 `toml:"io_gid"`
+	// BinaryName is the binary name of the runsc binary.
+	BinaryName string `toml:"binary_name"`
+	// Root is the runsc root directory.
+	Root string `toml:"root"`
+	// RunscConfig is a key/value map of all runsc flags.
+	RunscConfig map[string]string `toml:"runsc_config"`
+}
diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD
new file mode 100644
index 000000000..01716034c
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+proto_library(
+    name = "api",
+    srcs = [
+        "runtimeoptions.proto",
+    ],
+)
+
+go_library(
+    name = "runtimeoptions",
+    srcs = ["runtimeoptions.go"],
+    visibility = ["//pkg/shim/v2:__pkg__"],
+    deps = [
+        "//pkg/shim/v2/runtimeoptions:api_go_proto",
+        "@com_github_gogo_protobuf//proto:go_default_library",
+    ],
+)
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
new file mode 100644
index 000000000..1c1a0c5d1
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runtimeoptions
+
+import (
+	proto "github.com/gogo/protobuf/proto"
+	pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto"
+)
+
+type Options = pb.Options
+
+func init() {
+	proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options")
+}
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
new file mode 100644
index 000000000..edb19020a
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package runtimeoptions;
+
+// This is a version of the runtimeoptions CRI API that is vendored.
+//
+// Importing the full CRI package is a nightmare.
+message Options {
+  string type_url = 1;
+  string config_path = 2;
+}
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
new file mode 100644
index 000000000..1534152fc
--- /dev/null
+++ b/pkg/shim/v2/service.go
@@ -0,0 +1,824 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+	"context"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"sync"
+	"syscall"
+	"time"
+
+	"github.com/BurntSushi/toml"
+	"github.com/containerd/cgroups"
+	"github.com/containerd/console"
+	"github.com/containerd/containerd/api/events"
+	"github.com/containerd/containerd/api/types/task"
+	"github.com/containerd/containerd/errdefs"
+	"github.com/containerd/containerd/log"
+	"github.com/containerd/containerd/mount"
+	"github.com/containerd/containerd/namespaces"
+	"github.com/containerd/containerd/pkg/process"
+	"github.com/containerd/containerd/pkg/stdio"
+	"github.com/containerd/containerd/runtime"
+	"github.com/containerd/containerd/runtime/linux/runctypes"
+	"github.com/containerd/containerd/runtime/v2/shim"
+	taskAPI "github.com/containerd/containerd/runtime/v2/task"
+	"github.com/containerd/containerd/sys/reaper"
+	"github.com/containerd/typeurl"
+	"github.com/gogo/protobuf/types"
+	"golang.org/x/sys/unix"
+
+	"gvisor.dev/gvisor/pkg/shim/runsc"
+	"gvisor.dev/gvisor/pkg/shim/v1/proc"
+	"gvisor.dev/gvisor/pkg/shim/v1/utils"
+	"gvisor.dev/gvisor/pkg/shim/v2/options"
+	"gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions"
+	"gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+	empty   = &types.Empty{}
+	bufPool = sync.Pool{
+		New: func() interface{} {
+			buffer := make([]byte, 32<<10)
+			return &buffer
+		},
+	}
+)
+
+var _ = (taskAPI.TaskService)(&service{})
+
+// configFile is the default config file name. For containerd 1.2,
+// we assume that a config.toml should exist in the runtime root.
+const configFile = "config.toml"
+
+// New returns a new shim service that can be used via GRPC.
+func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
+	ep, err := newOOMEpoller(publisher)
+	if err != nil {
+		return nil, err
+	}
+	go ep.run(ctx)
+	s := &service{
+		id:        id,
+		context:   ctx,
+		processes: make(map[string]process.Process),
+		events:    make(chan interface{}, 128),
+		ec:        proc.ExitCh,
+		oomPoller: ep,
+		cancel:    cancel,
+	}
+	go s.processExits()
+	runsc.Monitor = reaper.Default
+	if err := s.initPlatform(); err != nil {
+		cancel()
+		return nil, fmt.Errorf("failed to initialize platform behavior: %w", err)
+	}
+	go s.forward(publisher)
+	return s, nil
+}
+
+// service is the shim implementation of a remote shim over GRPC.
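+//
+// The mutex mu guards the mutable state below it, most notably the
+// processes map and the id/bundle pair recorded by Create.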
+type service struct { + mu sync.Mutex + + context context.Context + task process.Process + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + opts options.Options + ec chan proc.Exit + oomPoller *epoller + + id string + bundle string + cancel func() +} + +func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + self, err := os.Executable() + if err != nil { + return nil, err + } + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + args := []string{ + "-namespace", ns, + "-address", containerdAddress, + "-publish-binary", containerdBinary, + } + cmd := exec.Command(self, args...) + cmd.Dir = cwd + cmd.Env = append(os.Environ(), "GOMAXPROCS=2") + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + return cmd, nil +} + +func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) { + cmd, err := newCommand(ctx, containerdBinary, containerdAddress) + if err != nil { + return "", err + } + address, err := shim.SocketAddress(ctx, id) + if err != nil { + return "", err + } + socket, err := shim.NewSocket(address) + if err != nil { + return "", err + } + defer socket.Close() + f, err := socket.File() + if err != nil { + return "", err + } + defer f.Close() + + cmd.ExtraFiles = append(cmd.ExtraFiles, f) + + if err := cmd.Start(); err != nil { + return "", err + } + defer func() { + if err != nil { + cmd.Process.Kill() + } + }() + // make sure to wait after start + go cmd.Wait() + if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil { + return "", err + } + if err := shim.WriteAddress("address", address); err != nil { + return "", err + } + if err := shim.SetScore(cmd.Process.Pid); err != nil { + return "", fmt.Errorf("failed to set OOM Score on shim: %w", err) + } + return address, nil +} + +func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{ + Force: true, + }); err != nil { + log.L.Printf("failed to remove runc container: %v", err) + } + if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + return &taskAPI.DeleteResponse{ + ExitedAt: time.Now(), + ExitStatus: 128 + uint32(unix.SIGKILL), + }, nil +} + +func (s *service) readRuntime(path string) (string, error) { + data, err := ioutil.ReadFile(filepath.Join(path, "runtime")) + if err != nil { + return "", err + } + return string(data), nil +} + +func (s *service) writeRuntime(path, runtime string) error { + return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600) +} + +// Create creates a new initial process and container with the underlying OCI +// runtime. +func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, fmt.Errorf("create namespace: %w", err) + } + + // Read from root for now. 
+ var opts options.Options + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + var path string + switch o := v.(type) { + case *runctypes.CreateOptions: // containerd 1.2.x + opts.IoUid = o.IoUid + opts.IoGid = o.IoGid + opts.ShimCgroup = o.ShimCgroup + case *runctypes.RuncOptions: // containerd 1.2.x + root := proc.RunscRoot + if o.RuntimeRoot != "" { + root = o.RuntimeRoot + } + + opts.BinaryName = o.Runtime + + path = filepath.Join(root, configFile) + if _, err := os.Stat(path); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("stat config file %q: %w", path, err) + } + // A config file in runtime root is not required. + path = "" + } + case *runtimeoptions.Options: // containerd 1.3.x+ + if o.ConfigPath == "" { + break + } + if o.TypeUrl != options.OptionType { + return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl) + } + path = o.ConfigPath + default: + return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl) + } + if path != "" { + if _, err = toml.DecodeFile(path, &opts); err != nil { + return nil, fmt.Errorf("decode config file %q: %w", path, err) + } + } + } + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: opts.BinaryName, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil { + return nil, err + } + defer func() { + if err != nil { + if err := mount.UnmountAll(rootfs, 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + r.Bundle, + filepath.Join(r.Bundle, "work"), + ns, + s.platform, + config, + &opts, + rootfs, + ) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + + // Set up OOM notification on the sandbox's cgroup. This is done on + // sandbox create since the sandbox process will be created here. + pid := process.Pid() + if pid > 0 { + cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid)) + if err != nil { + return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err) + } + if err := s.oomPoller.add(s.id, cg); err != nil { + return nil, fmt.Errorf("add cg to OOM monitor: %w", err) + } + } + s.task = process + s.opts = opts + return &taskAPI.CreateTaskResponse{ + Pid: uint32(process.Pid()), + }, nil + +} + +// Start starts a process. +func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + // TODO: Set the cgroup and oom notifications on restore. 
+ // https://github.com/google/gvisor-containerd-shim/issues/58 + return &taskAPI.StartResponse{ + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + isTask := r.ExecID == "" + if !isTask { + s.mu.Lock() + delete(s.processes, r.ExecID) + s.mu.Unlock() + } + if isTask && s.platform != nil { + s.platform.Close() + } + return &taskAPI.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + p := s.processes[r.ExecID] + s.mu.Unlock() + if p != nil { + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) + } + p = s.task + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{ + ID: r.ExecID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ExecID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resizes the terminal of a process. +func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &taskAPI.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause the container. +func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume the container. +func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill a process with the provided signal. 
+func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// Pids returns all pids inside the container. +func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &taskAPI.PidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Connect returns shim information such as the shim's pid. +func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) { + var pid int + if s.task != nil { + pid = s.task.Pid() + } + return &taskAPI.ConnectResponse{ + ShimPid: uint32(os.Getpid()), + TaskPid: uint32(pid), + }, nil +} + +func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) { + s.cancel() + os.Exit(0) + return empty, nil +} + +func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + stats, err := rs.Stats(ctx, s.id) + if err != nil { + return nil, err + } + + // gvisor currently (as of 2020-03-03) only returns the total memory + // usage and current PID value[0]. However, we copy the common fields here + // so that future updates will propagate correct information. We're + // using the cgroups.Metrics structure so we're returning the same type + // as runc. 
+ // + // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 + data, err := typeurl.MarshalAny(&cgroups.Metrics{ + CPU: &cgroups.CPUStat{ + Usage: &cgroups.CPUUsage{ + Total: stats.Cpu.Usage.Total, + Kernel: stats.Cpu.Usage.Kernel, + User: stats.Cpu.Usage.User, + PerCPU: stats.Cpu.Usage.Percpu, + }, + Throttling: &cgroups.Throttle{ + Periods: stats.Cpu.Throttling.Periods, + ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, + ThrottledTime: stats.Cpu.Throttling.ThrottledTime, + }, + }, + Memory: &cgroups.MemoryStat{ + Cache: stats.Memory.Cache, + Usage: &cgroups.MemoryEntry{ + Limit: stats.Memory.Usage.Limit, + Usage: stats.Memory.Usage.Usage, + Max: stats.Memory.Usage.Max, + Failcnt: stats.Memory.Usage.Failcnt, + }, + Swap: &cgroups.MemoryEntry{ + Limit: stats.Memory.Swap.Limit, + Usage: stats.Memory.Swap.Usage, + Max: stats.Memory.Swap.Max, + Failcnt: stats.Memory.Swap.Failcnt, + }, + Kernel: &cgroups.MemoryEntry{ + Limit: stats.Memory.Kernel.Limit, + Usage: stats.Memory.Kernel.Usage, + Max: stats.Memory.Kernel.Max, + Failcnt: stats.Memory.Kernel.Failcnt, + }, + KernelTCP: &cgroups.MemoryEntry{ + Limit: stats.Memory.KernelTCP.Limit, + Usage: stats.Memory.KernelTCP.Usage, + Max: stats.Memory.KernelTCP.Max, + Failcnt: stats.Memory.KernelTCP.Failcnt, + }, + }, + Pids: &cgroups.PidsStat{ + Current: stats.Pids.Current, + Limit: stats.Pids.Limit, + }, + }) + if err != nil { + return nil, err + } + return &taskAPI.StatsResponse{ + Stats: data, + }, nil +} + +// Update updates a running container. +func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + p.Wait() + + return &taskAPI.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *service) checkProcesses(e proc.Exit) { + // TODO(random-liu): Add `shouldKillAll` logic if container pid + // namespace is supported. + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). 
+ Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &events.TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *service) allProcesses() (o []process.Process) { + s.mu.Lock() + defer s.mu.Unlock() + for _, p := range s.processes { + o = append(o, p) + } + if s.task != nil { + o = append(o, s.task) + } + return o +} + +func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + s.mu.Lock() + p := s.task + s.mu.Unlock() + if p == nil { + return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) + } + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *service) forward(publisher shim.Publisher) { + for e := range s.events { + ctx, cancel := context.WithTimeout(s.context, 5*time.Second) + err := publisher.Publish(ctx, getTopic(e), e) + cancel() + if err != nil { + // Should not happen. + panic(fmt.Errorf("post event: %w", err)) + } + } +} + +func (s *service) getProcess(execID string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + if execID == "" { + return s.task, nil + } + p := s.processes[execID] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) + } + return p, nil +} + +func getTopic(e interface{}) string { + switch e.(type) { + case *events.TaskCreate: + return runtime.TaskCreateEventTopic + case *events.TaskStart: + return runtime.TaskStartEventTopic + case *events.TaskOOM: + return runtime.TaskOOMEventTopic + case *events.TaskExit: + return runtime.TaskExitEventTopic + case *events.TaskDelete: + return runtime.TaskDeleteEventTopic + case *events.TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *events.TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) { + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + runsc.FormatLogPath(r.ID, options.RunscConfig) + runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go new file mode 100644 index 000000000..1800ab90b --- /dev/null +++ b/pkg/shim/v2/service_linux.go @@ -0,0 +1,108 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index e131455f7..ae0fe1522 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -12,6 +12,7 @@ go_library( "sleep_unsafe.go", ], visibility = ["//:sandbox"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index af47e2ba1..1dd11707d 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -379,10 +379,7 @@ func TestRace(t *testing.T) { // TestRaceInOrder tests that multiple wakers can continuously send wake requests to // the sleeper and that the wakers are retrieved in the order asserted. func TestRaceInOrder(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - w := make([]Waker, wakers) + w := make([]Waker, 10000) s := Sleeper{} // Associate each waker and start goroutines that will assert them. 
@@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) { s.AddWaker(&w[i], i) } go func() { - n := 0 - for n < wakeRequests { - wk := w[n%len(w)] - wk.Assert() - n++ + for i := range w { + w[i].Assert() } }() // Wait for all wake up notifications from all wakers. - for i := 0; i < wakeRequests; i++ { - v, _ := s.Fetch(true) - if got, want := v, i%wakers; got != want { - t.Fatalf("got %d want %d", got, want) + for want := range w { + got, _ := s.Fetch(true) + if got != want { + t.Fatalf("got %d want %d", got, want) } } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index f68c12620..118805492 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -75,6 +75,8 @@ package sleep import ( "sync/atomic" "unsafe" + + "gvisor.dev/gvisor/pkg/sync" ) const ( @@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { // // This struct is thread-safe, that is, its methods can be called concurrently // by multiple goroutines. +// +// Note, it is not safe to copy a Waker as its fields are modified by value +// (the pointer fields are individually modified with atomic operations). type Waker struct { + _ sync.NoCopy + // s is the sleeper that this waker can wake up. Only one sleeper at a // time is allowed. This field can have three classes of values: // nil -- the waker is not asserted: it either is not associated with diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 2b1350135..089b3bbef 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,9 +1,47 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( + name = "pending_list", + out = "pending_list.go", + package = "state", + prefix = "pending", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectEncodeState", + "ElementMapper": "pendingMapper", + "Linker": "*pendingEntry", + }, +) + +go_template_instance( + name = "deferred_list", + out = "deferred_list.go", + package = "state", + prefix = "deferred", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectEncodeState", + "ElementMapper": "deferredMapper", + "Linker": "*deferredEntry", + }, +) + +go_template_instance( + name = "complete_list", + out = "complete_list.go", + package = "state", + prefix = "complete", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectDecodeState", + "Linker": "*objectDecodeState", + }, +) + +go_template_instance( name = "addr_range", out = "addr_range.go", package = "state", @@ -29,7 +67,7 @@ go_template_instance( types = { "Key": "uintptr", "Range": "addrRange", - "Value": "reflect.Value", + "Value": "*objectEncodeState", "Functions": "addrSetFunctions", }, ) @@ -39,32 +77,24 @@ go_library( srcs = [ "addr_range.go", "addr_set.go", + "complete_list.go", "decode.go", + "decode_unsafe.go", + "deferred_list.go", "encode.go", "encode_unsafe.go", - "map.go", - "printer.go", + "pending_list.go", "state.go", + "state_norace.go", + "state_race.go", "stats.go", + "types.go", ], marshal = False, stateify = False, visibility = ["//:sandbox"], deps = [ - ":object_go_proto", - "@com_github_golang_protobuf//proto:go_default_library", + "//pkg/log", + "//pkg/state/wire", ], ) - -proto_library( - name = "object", - srcs = ["object.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "state_test", - timeout = "long", - srcs = ["state_test.go"], - library = ":state", -) diff --git 
a/pkg/state/README.md b/pkg/state/README.md new file mode 100644 index 000000000..1aa401193 --- /dev/null +++ b/pkg/state/README.md @@ -0,0 +1,158 @@ +# State Encoding and Decoding + +The state package implements the encoding and decoding of data structures for +`go_stateify`. This package is designed for use cases that the standard +encoding packages, e.g. `gob` and `json`, do not serve. Principally: + +* This package operates on complex object graphs and accurately serializes and + restores all relationships. That is, you can have things like intrusive + pointers, cycles, and pointer chains of arbitrary depth. These are not + handled appropriately by existing encoders. This is not an implementation + flaw: the formats themselves are not capable of representing these graphs, + as they can only generate directed trees. + +* This package allows installing order-dependent load callbacks and then + resolves that graph at load time, with cycle detection. No analogous feature + is possible with the standard encoders. + +* This package handles the resolution of interfaces, based on a registered + type name. For interface objects, type information is saved in the + serialized format. This is generally true for `gob` as well, but it works + differently. + +Here's an overview of how encoding and decoding works. + +## Encoding + +Encoding produces a `statefile`, which contains a list of chunks of the form +`(header, payload)`. The payload can either be some raw data, or a series of +encoded wire objects representing some object graph. All encoded objects are +defined in the `wire` subpackage. + +Encoding of an object graph begins with `encodeState.Save`. + +### 1. Memory Map & Encoding + +To discover relationships between potentially interdependent data structures +(for example, a struct may contain pointers to members of other data +structures), the encoder first walks the object graph and constructs a memory +map of the objects in the input graph. As this walk progresses, objects are +queued in the `pending` list and items are placed on the `deferred` list as they +are discovered. No single object will be encoded multiple times, but the +discovered relationships between objects may change as more parts of the overall +object graph are discovered. + +The encoder starts at the root object and recursively visits all reachable +objects, recording the address ranges containing the underlying data for each +object. This is stored as a segment set (`addrSet`), mapping address ranges to +the object occupying each range; see `encodeState.values`. Note that there is +special handling for zero-sized types and map objects during this process. + +Additionally, the encoder assigns each object a unique identifier which is used +to indicate relationships between objects in the statefile; see `objectID` in +`encode.go`. + +### 2. Type Serialization + +The encoder will subsequently serialize all information about discovered types, +including field names. These are used during decoding to reconcile the encoded +types with the internally registered types. + +### 3. Object Serialization + +With a full address map, and all objects correctly encoded, all object encodings +are serialized. The assigned `objectID`s aren't explicitly encoded in the +statefile; the order of object messages in the stream determines their IDs.
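+ +To make the first bullet above concrete, here is a short, self-contained sketch. It is illustrative only: it uses plain `encoding/gob` with made-up `node` and `pair` types and is not part of this package. Because `gob` flattens pointers, a shared pointer is silently split into two independent copies on decode, which is exactly the kind of relationship this package preserves. + +```go +package main + +import ( + "bytes" + "encoding/gob" + "fmt" +) + +type node struct{ V int } + +type pair struct { + A, B *node +} + +func main() { + n := &node{V: 1} + p := pair{A: n, B: n} // A and B alias the same node. + + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(p); err != nil { + panic(err) + } + var q pair + if err := gob.NewDecoder(&buf).Decode(&q); err != nil { + panic(err) + } + + // gob flattens pointers, so the aliasing is lost on decode: + // mutating q.A is no longer visible through q.B. + q.A.V = 2 + fmt.Println(q.A.V, q.B.V) // Prints "2 1", not "2 2". +} +``` + +Decoding the same graph with this package instead yields a single object referenced twice, as the worked example below shows.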
+ +### Example + +Given the following data structure definitions: + +```go +type system struct { + o *outer + i *inner +} + +type outer struct { + a int64 + cn *container +} + +type container struct { + n uint64 + elem *inner +} + +type inner struct { + c container + x, y uint64 +} +``` + +Initialized like this: + +```go +o := outer{ + a: 10, + cn: nil, +} +i := inner{ + x: 20, + y: 30, + c: container{}, +} +s := system{ + o: &o, + i: &i, +} + +o.cn = &i.c +o.cn.elem = &i + +``` + +Encoding will produce an object stream like this: + +``` +g0r1 = struct{ + i: g0r3, + o: g0r2, +} +g0r2 = struct{ + a: 10, + cn: g0r3.c, +} +g0r3 = struct{ + c: struct{ + elem: g0r3, + n: 0u, + }, + x: 20u, + y: 30u, +} +``` + +Note how `g0r3.c` is correctly encoded as the underlying `container` object for +`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i` +being discovered after the pointer to it in `system.o.cn`. Also note that +decoding isn't strictly reliant on the order of the encoded object stream, as +long as the relationships between objects are correctly encoded. + +## Decoding + +Decoding reads the statefile and reconstructs the object graph. Decoding begins +in `decodeState.Load`. Decoding is performed in a single pass over the object +stream in the statefile, followed by a pass over all deserialized objects to +fire their loading callbacks in the correct order. Note that introducing cycles +is possible here, but these are detected and an error will be returned. + +Decoding is relatively straightforward. For most primitive values, the decoder +constructs an appropriate object and fills it with the values encoded in the +statefile. Pointers need special handling, as they must point to a value +allocated elsewhere. When values are constructed, the decoder indexes them by +their `objectID`s in `decodeState.objectsByID`. The targets of pointers are +resolved by searching for the target in this index by its `objectID`; see +`decodeState.register`. For pointers to values inside another value (fields in a +struct, elements of an array), the decoder uses the accessor path to walk to +the appropriate location; see `walkChild`. diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 590c241a3..c9971cdf6 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -17,28 +17,49 @@ package state import ( "bytes" "context" - "encoding/binary" - "errors" "fmt" - "io" + "math" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// objectState represents an object that may be in the process of being +// internalCallback is an interface called on object completion. +// +// There are two implementations: objectDecodeState & userCallback. +type internalCallback interface { + // source returns the dependent object. May be nil. + source() *objectDecodeState + + // callbackRun executes the callback. + callbackRun() +} + +// userCallback is an implementation of internalCallback. +type userCallback func() + +// source implements internalCallback.source. +func (userCallback) source() *objectDecodeState { + return nil +} + +// callbackRun implements internalCallback.callbackRun. +func (uc userCallback) callbackRun() { + uc() +} + +// objectDecodeState represents an object that may be in the process of being // decoded. Specifically, it represents either a decoded object, or an // interest in a future object that will be decoded.
When that interest is // registered (via register), the storage for the object will be created, but // it will not be decoded until the object is encountered in the stream. -type objectState struct { +type objectDecodeState struct { // id is the id for this object. - // - // If this field is zero, then this is an anonymous (unregistered, - // non-reference primitive) object. This is immutable. - id uint64 + id objectID + + // typ is the typeID for this object. This may be zero if this is not a + // type-registered structure. + typ typeID // obj is the object. This may or may not be valid yet, depending on // whether complete returns true. However, regardless of whether the @@ -57,69 +78,52 @@ type objectState struct { // blockedBy is the number of dependencies this object has. blockedBy int - // blocking is a list of the objects blocked by this one. - blocking []*objectState + // callbacksInline is inline storage for callbacks. + callbacksInline [2]internalCallback // callbacks is a set of callbacks to execute on load. - callbacks []func() - - // path is the decoding path to the object. - path recoverable -} - -// complete indicates the object is complete. -func (os *objectState) complete() bool { - return os.blockedBy == 0 && len(os.callbacks) == 0 -} - -// checkComplete checks for completion. If the object is complete, pending -// callbacks will be executed and checkComplete will be called on downstream -// objects (those depending on this one). -func (os *objectState) checkComplete(stats *Stats) { - if os.blockedBy > 0 { - return - } - stats.Start(os.obj) + callbacks []internalCallback - // Fire all callbacks. - for _, fn := range os.callbacks { - fn() - } - os.callbacks = nil - - // Clear all blocked objects. - for _, other := range os.blocking { - other.blockedBy-- - other.checkComplete(stats) - } - os.blocking = nil - stats.Done() + completeEntry } -// waitFor queues a dependency on the given object. -func (os *objectState) waitFor(other *objectState, callback func()) { - os.blockedBy++ - other.blocking = append(other.blocking, os) - if callback != nil { - other.callbacks = append(other.callbacks, callback) +// addCallback adds a callback to the objectDecodeState. +func (ods *objectDecodeState) addCallback(ic internalCallback) { + if ods.callbacks == nil { + ods.callbacks = ods.callbacksInline[:0] } + ods.callbacks = append(ods.callbacks, ic) } // findCycleFor returns when the given object is found in the blocking set. -func (os *objectState) findCycleFor(target *objectState) []*objectState { - for _, other := range os.blocking { - if other == target { - return []*objectState{target} +func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState { + for _, ic := range ods.callbacks { + other := ic.source() + if other != nil && other == target { + return []*objectDecodeState{target} } else if childList := other.findCycleFor(target); childList != nil { return append(childList, other) } } - return nil + + // This should not occur. + Failf("no deadlock found?") + panic("unreachable") } // findCycle finds a dependency cycle. -func (os *objectState) findCycle() []*objectState { - return append(os.findCycleFor(os), os) +func (ods *objectDecodeState) findCycle() []*objectDecodeState { + return append(ods.findCycleFor(ods), ods) +} + +// source implements internalCallback.source. +func (ods *objectDecodeState) source() *objectDecodeState { + return ods +} + +// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() { + ods.blockedBy-- } // decodeState is a graph of objects in the process of being decoded. @@ -137,30 +141,66 @@ type decodeState struct { // ctx is the decode context. ctx context.Context + // r is the input stream. + r wire.Reader + + // types is the type database. + types typeDecodeDatabase + // objectByID is the set of objects in progress. - objectsByID map[uint64]*objectState + objectsByID []*objectDecodeState // deferred are objects that have been read, by no interest has been // registered yet. These will be decoded once interest in registered. - deferred map[uint64]*pb.Object + deferred map[objectID]wire.Object - // outstanding is the number of outstanding objects. - outstanding uint32 + // pending is the set of objects that are not yet complete. + pending completeList - // r is the input stream. - r io.Reader - - // stats is the passed stats object. - stats *Stats - - // recoverable is the panic recover facility. - recoverable + // stats tracks time data. + stats Stats } // lookup looks up an object in decodeState or returns nil if no such object // has been previously registered. -func (ds *decodeState) lookup(id uint64) *objectState { - return ds.objectsByID[id] +func (ds *decodeState) lookup(id objectID) *objectDecodeState { + if len(ds.objectsByID) < int(id) { + return nil + } + return ds.objectsByID[id-1] +} + +// checkComplete checks for completion. +func (ds *decodeState) checkComplete(ods *objectDecodeState) bool { + // Still blocked? + if ods.blockedBy > 0 { + return false + } + + // Track stats if relevant. + if ods.callbacks != nil && ods.typ != 0 { + ds.stats.start(ods.typ) + defer ds.stats.done() + } + + // Fire all callbacks. + for _, ic := range ods.callbacks { + ic.callbackRun() + } + + // Mark completed. + cbs := ods.callbacks + ods.callbacks = nil + ds.pending.Remove(ods) + + // Recursively check others. + for _, ic := range cbs { + if other := ic.source(); other != nil && other.blockedBy == 0 { + ds.checkComplete(other) + } + } + + return true // All set. } // wait registers a dependency on an object. @@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState { // As a special case, we always allow _useable_ references back to the first // decoding object because it may have fields that are already decoded. We also // allow trivial self reference, since they can be handled internally. -func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { +func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) { switch id { - case 0: - // Nil pointer; nothing to wait for. - fallthrough case waiter.id: // Trivial self reference. fallthrough @@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { return } + // Mark as blocked. + waiter.blockedBy++ + // No nil can be returned here. - waiter.waitFor(ds.lookup(id), callback) + other := ds.lookup(id) + if callback != nil { + // Add the additional user callback. + other.addCallback(userCallback(callback)) + } + + // Mark waiter as unblocked. + other.addCallback(waiter) } // waitObject notes a blocking relationship. -func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { - if rv, ok := p.Value.(*pb.Object_RefValue); ok { +func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) { + if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 { // Refs can encode pointers and maps. 
- ds.wait(os, rv.RefValue, callback) - } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { + ds.wait(ods, objectID(rv.Root), callback) + } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 { // See decodeObject; we need to wait for the array (if non-nil). - ds.wait(os, sv.SliceValue.RefValue, callback) - } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { + ds.wait(ods, objectID(sv.Ref.Root), callback) + } else if iv, ok := encoded.(*wire.Interface); ok { // It's an interface (wait recursively). - ds.waitObject(os, iv.InterfaceValue.Value, callback) + ds.waitObject(ods, iv.Value, callback) } else if callback != nil { // Nothing to wait for: execute the callback immediately. callback() } } +// walkChild returns a child object from obj, given an accessor path. This is +// the decode-side equivalent to traverse in encode.go. +// +// For the purposes of this function, a child object is either a field within a +// struct or an array element, with one such indirection per element in +// path. The returned value may be an unexported field, so it may not be +// directly assignable. See unsafePointerTo. +func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { + // See wire.Ref.Dots. The path here is specified in reverse order. + for i := len(path) - 1; i >= 0; i-- { + switch pc := path[i].(type) { + case *wire.FieldName: // Must be a struct. + if obj.Kind() != reflect.Struct { + Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj) + } + obj = obj.FieldByName(string(*pc)) + case wire.Index: // Embedded. + if obj.Kind() != reflect.Array { + Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj) + } + obj = obj.Index(int(pc)) + default: + panic("unreachable: switch should be exhaustive") + } + } + return obj +} + // register registers a decode with a type. // // This type is only used to instantiate a new object if it has not been -// registered previously. -func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { - os, ok := ds.objectsByID[id] - if ok { - return os +// registered previously. This depends on the type provided if none is +// available in the object itself. +func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value { + // Grow the objectsByID slice. + id := objectID(r.Root) + if len(ds.objectsByID) < int(id) { + ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...) + } + + // Does this object already exist? + ods := ds.objectsByID[id-1] + if ods != nil { + return walkChild(r.Dots, ods.obj) + } + + // Create the object. + if len(r.Dots) != 0 { + typ = ds.findType(r.Type) } + v := reflect.New(typ) + ods = &objectDecodeState{ + id: id, + obj: v.Elem(), + } + ds.objectsByID[id-1] = ods + ds.pending.PushBack(ods) - // Record in the object index. - if typ.Kind() == reflect.Map { - os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} - } else { - os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} + // Process any deferred objects & callbacks. + if encoded, ok := ds.deferred[id]; ok { + delete(ds.deferred, id) + ds.decodeObject(ods, ods.obj, encoded) } - ds.objectsByID[id] = os - if o, ok := ds.deferred[id]; ok { - // There is a deferred object. - delete(ds.deferred, id) // Free memory.
- ds.decodeObject(os, os.obj, o, "", nil) - } else { - // There is no deferred object. - ds.outstanding++ + return walkChild(r.Dots, ods.obj) +} + +// objectDecoder is for decoding structs. +type objectDecoder struct { + // ds is decodeState. + ds *decodeState + + // ods is the current object being decoded. + ods *objectDecodeState + + // reconciledTypeEntry is the reconciled type information. + rte *reconciledTypeEntry + + // encoded is the encoded object state. + encoded *wire.Struct +} + +// load is a helper for the public methods on Source. +func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) { + // Note that we have reconciled the type and may remap the fields here + // to match what's expected by the decoder. The "slot" parameter here + // is in terms of the local type, where the fields in the encoded + // object are in terms of the wire object's type, which might be in a + // different order (but will have the same fields). + v := *od.encoded.Field(od.rte.FieldOrder[slot]) + od.ds.decodeObject(od.ods, objPtr.Elem(), v) + if wait { + // Mark this individual object as a blocker. + od.ds.waitObject(od.ods, v, fn) + } +} - return os +} + +// afterLoad implements Source.AfterLoad. +func (od *objectDecoder) afterLoad(fn func()) { + // Queue the local callback; this will execute when all of the above + // data dependencies have been cleared. + od.ods.addCallback(userCallback(fn)) } // decodeStruct decodes a struct value. -func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { - // Set the fields. - m := Map{newInternalMap(nil, ds, os)} - defer internalMapPool.Put(m.internalMap) - for _, field := range s.Fields { - m.data = append(m.data, entry{ - name: field.Name, - object: field.Value, - }) - } - - // Sort the fields for efficient searching. - // - // Technically, these should already appear in sorted order in the - // state ordering, so this cost is effectively a single scan to ensure - // that the order is correct. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - } - - // Invoke the load; this will recursively decode other objects. - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the loader. - fns.invokeLoad(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow anonymous empty structs. - return - } else { +func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) { + if encoded.TypeID == 0 { + // Allow anonymous empty structs, but only if the encoded + // object also has no fields. + if encoded.Fields() == 0 && obj.NumField() == 0 { + return + } + // Propagate an error. - panic(fmt.Errorf("unregistered type %s", obj.Type())) + Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name()) + } + + // Lookup the object type. + rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type()) + ods.typ = typeID(encoded.TypeID) + + // Invoke the loader. + od := objectDecoder{ + ds: ds, + ods: ods, + rte: rte, + encoded: encoded, + } + ds.stats.start(ods.typ) + defer ds.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateLoad(Source{internal: od}) + } } // decodeMap decodes a map value.
-func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { +func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) { if obj.IsNil() { + // See pointerTo. obj.Set(reflect.MakeMap(obj.Type())) } - for i := 0; i < len(m.Keys); i++ { + for i := 0; i < len(encoded.Keys); i++ { // Decode the objects. kv := reflect.New(obj.Type().Key()).Elem() vv := reflect.New(obj.Type().Elem()).Elem() - ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) - ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) - ds.waitObject(os, m.Keys[i], nil) - ds.waitObject(os, m.Values[i], nil) + ds.decodeObject(ods, kv, encoded.Keys[i]) + ds.decodeObject(ods, vv, encoded.Values[i]) + ds.waitObject(ods, encoded.Keys[i], nil) + ds.waitObject(ods, encoded.Values[i], nil) // Set in the map. obj.SetMapIndex(kv, vv) @@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) } // decodeArray decodes an array value. -func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { - if len(a.Contents) != obj.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) +func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) { + if len(encoded.Contents) != obj.Len() { + Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents)) } // Decode the contents into the array. - for i := 0; i < len(a.Contents); i++ { - ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) - ds.waitObject(os, a.Contents[i], nil) + for i := 0; i < len(encoded.Contents); i++ { + ds.decodeObject(ods, obj.Index(i), encoded.Contents[i]) + ds.waitObject(ods, encoded.Contents[i], nil) } } -// decodeInterface decodes an interface value. -func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { - // Is this a nil value? - if i.Type == "" { - return // Just leave obj alone. +// findType finds the type for the given wire.TypeSpec. +func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type { + switch x := t.(type) { + case wire.TypeID: + typ := ds.types.LookupType(typeID(x)) + rte := ds.types.Lookup(typeID(x), typ) + return rte.LocalType + case *wire.TypeSpecPointer: + return reflect.PtrTo(ds.findType(x.Type)) + case *wire.TypeSpecArray: + return reflect.ArrayOf(int(x.Count), ds.findType(x.Type)) + case *wire.TypeSpecSlice: + return reflect.SliceOf(ds.findType(x.Type)) + case *wire.TypeSpecMap: + return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value)) + default: + // Should not happen. + Failf("unknown type %#v", t) } + panic("unreachable") +} - // Get the dispatchable type. This may not be used if the given - // reference has already been resolved, but if not we need to know the - // type to create. - t, ok := registeredTypes.lookupType(i.Type) - if !ok { - panic(fmt.Errorf("no valid type for %q", i.Type)) +// decodeInterface decodes an interface value. +func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) { + if _, ok := encoded.Type.(wire.TypeSpecNil); ok { + // Special case; the nil object. Just decode directly, which + // will read nil from the wire (if encoded correctly). + ds.decodeObject(ods, obj, encoded.Value) + return } - if obj.Kind() != reflect.Map { - // Set the obj to be the given typed value; this actually sets - // obj to be a non-zero value -- namely, it inserts type - // information.
There's no need to do this for maps. - obj.Set(reflect.Zero(t)) + // We now need to resolve the actual type. + typ := ds.findType(encoded.Type) + + // We need to imbue type information here, then we can proceed to + // decode normally. In order to avoid issues with setting value-types, + // we create a new non-interface version of this object. We will then + // set the interface object to be equal to whatever we decode. + origObj := obj + obj = reflect.New(typ).Elem() + defer origObj.Set(obj) + + // With the object now having sufficient type information to actually + // have Set called on it, we can proceed to decode the value. + ds.decodeObject(ods, obj, encoded.Value) +} + +// isFloatEq determines if x and y represent the same value. +func isFloatEq(x float64, y float64) bool { + switch { + case math.IsNaN(x): + return math.IsNaN(y) + case math.IsInf(x, 1): + return math.IsInf(y, 1) + case math.IsInf(x, -1): + return math.IsInf(y, -1) + default: + return x == y } +} - // Decode the dereferenced element; there is no need to wait here, as - // the interface object shares the current object state. - ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) +// isComplexEq determines if x and y represent the same value. +func isComplexEq(x complex128, y complex128) bool { + return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y)) } // decodeObject decodes an object value. -func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { - ds.push(false, format, param) - ds.stats.Add(obj) - ds.stats.Start(obj) - - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - obj.SetBool(x.BoolValue) - case *pb.Object_StringValue: - obj.SetString(string(x.StringValue)) - case *pb.Object_Int64Value: - obj.SetInt(x.Int64Value) - if obj.Int() != x.Int64Value { - panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) +func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) { + switch x := encoded.(type) { + case wire.Nil: // Fast path: first. + // We leave obj alone here. That's because if obj represents an + // interface, it may have been imbued with type information in + // decodeInterface, and we don't want to destroy that. + case *wire.Ref: + // Nil pointers may be encoded in a "forceValue" context. For + // those we just leave it alone as the value will already be + // correct (nil). + if id := objectID(x.Root); id == 0 { + return } - case *pb.Object_Uint64Value: - obj.SetUint(x.Uint64Value) - if obj.Uint() != x.Uint64Value { - panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_DoubleValue: - obj.SetFloat(x.DoubleValue) - if obj.Float() != x.DoubleValue { - panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_RefValue: - // Resolve the pointer itself, even though the object may not - // be decoded yet. You need to use wait() in order to ensure - // that is the case. See wait above, and Map.Barrier. - if id := x.RefValue; id != 0 { - // Decoding the interface should have imparted type - // information, so from this point it's safe to resolve - // and use this dynamic information for actually - // creating the object in register. - // - // (For non-interfaces this is a no-op).
- dyntyp := reflect.TypeOf(obj.Interface()) - if dyntyp.Kind() == reflect.Map { - // Remove the map object count here to avoid - // double counting, as this object will be - // counted again when it gets processed later. - // We do not add a reference count as the - // reference is artificial. - ds.stats.Remove(obj) - obj.Set(ds.register(id, dyntyp).obj) - } else if dyntyp.Kind() == reflect.Ptr { - ds.push(true /* dereference */, "", nil) - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) - ds.pop() - } else { - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + + // Note that if this is a map type, we go through a level of + // indirection to allow for map aliasing. + if obj.Kind() == reflect.Map { + v := ds.register(x, obj.Type()) + if v.IsNil() { + // Note that we don't want to clobber the map + // if it has already been decoded by decodeMap. + // We just make it so that we have a consistent + // reference when that eventually does happen. + v.Set(reflect.MakeMap(v.Type())) } - } else { - // We leave obj alone here. That's because if obj - // represents an interface, it may have been embued - // with type information in decodeInterface, and we - // don't want to destroy that information. + obj.Set(v) + return } - case *pb.Object_SliceValue: - // It's okay to slice the array here, since the contents will - // still be provided later on. These semantics are a bit - // strange but they are handled in the Map.Barrier properly. - // - // The special semantics of zero ref apply here too. - if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { - v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) - obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) + + // Normal assignment: authoritative only if no dots. + v := ds.register(x, obj.Type().Elem()) + if v.IsValid() { + obj.Set(unsafePointerTo(v)) } - case *pb.Object_ArrayValue: - ds.decodeArray(os, obj, x.ArrayValue) - case *pb.Object_StructValue: - ds.decodeStruct(os, obj, x.StructValue) - case *pb.Object_MapValue: - ds.decodeMap(os, obj, x.MapValue) - case *pb.Object_InterfaceValue: - ds.decodeInterface(os, obj, x.InterfaceValue) - case *pb.Object_ByteArrayValue: - copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Uint16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]uint16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Bool: + obj.SetBool(bool(x)) + case wire.Int: + obj.SetInt(int64(x)) + if obj.Int() != int64(x) { + Failf("signed integer truncated from %v to %v", int64(x), obj.Int()) } - for i := range s { - t[i] = uint16(s[i]) + case wire.Uint: + obj.SetUint(uint64(x)) + if obj.Uint() != uint64(x) { + Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint()) } - case *pb.Object_Uint32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details.
- s := x.Int16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]int16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Float32: + obj.SetFloat(float64(x)) + case wire.Float64: + obj.SetFloat(float64(x)) + if !isFloatEq(obj.Float(), float64(x)) { + Failf("floating point number truncated from %v to %v", float64(x), obj.Float()) } - for i := range s { - t[i] = int16(s[i]) + case *wire.Complex64: + obj.SetComplex(complex128(*x)) + case *wire.Complex128: + obj.SetComplex(complex128(*x)) + if !isComplexEq(obj.Complex(), complex128(*x)) { + Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex()) } - case *pb.Object_Int32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) + case *wire.String: + obj.SetString(string(*x)) + case *wire.Slice: + // See *wire.Ref above; same applies. + if id := objectID(x.Ref.Root); id == 0 { + return + } + // Note that it's fine to slice the array here and assume that + // contents will still be filled in later on. + typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. + v := ds.register(&x.Ref, typ) + obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + case *wire.Array: + ds.decodeArray(ods, obj, x) + case *wire.Struct: + ds.decodeStruct(ods, obj, x) + case *wire.Map: + ds.decodeMap(ods, obj, x) + case *wire.Interface: + ds.decodeInterface(ods, obj, x) default: // Should not happen, not propagated as an error. - panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) - } - - ds.stats.Done() - ds.pop() -} - -func copyArray(dest reflect.Value, src reflect.Value) { - if dest.Len() != src.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) + Failf("unknown object %#v for %q", encoded, obj.Type().Name()) } - reflect.Copy(dest, castSlice(src, dest.Type().Elem())) } -// Deserialize deserializes the object state. +// Load deserializes the object graph rooted at obj. // // This function may panic and should be run in safely(). -func (ds *decodeState) Deserialize(obj reflect.Value) { - ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} - ds.outstanding = 1 // The root object. +func (ds *decodeState) Load(obj reflect.Value) { + ds.stats.init() + defer ds.stats.fini(func(id typeID) string { + return ds.types.LookupName(id) + }) + + // Create the root object. + ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ + id: 1, + obj: obj, + }) + + // Read the number of objects. + lastID, object, err := ReadHeader(ds.r) + if err != nil { + Failf("header error: %w", err) + } + if !object { + Failf("object missing") + } + + // Decode all objects. + var ( + encoded wire.Object + ods *objectDecodeState + id = objectID(1) + tid = typeID(1) + ) + if err := safely(func() { + // Decode all objects in the stream. + // + // Note that the structure of this decoding loop should match + // the raw decoding loop in printer.go. + for id <= objectID(lastID) { + // Unmarshal the object. + encoded = wire.Load(ds.r) + + // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok { + ds.types.Register(wt) + tid++ + encoded = nil + continue + } - // Decode all objects in the stream. - // - // See above, we never process objects while we have no outstanding - // interests (other than the very first object). - for id := uint64(1); ds.outstanding > 0; id++ { - os := ds.lookup(id) - ds.stats.Start(os.obj) - - o, err := ds.readObject() - if err != nil { - panic(err) - } + // Actually resolve the object. + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we defer + // decoding until interest is registered. + ds.deferred[id] = encoded + } - if os != nil { - // Decode the object. - ds.from = &os.path - ds.decodeObject(os, os.obj, o, "", nil) - ds.outstanding-- + // For error handling. + ods = nil + encoded = nil + id++ + } + }); err != nil { + // Include as much information as we can, taking into account + // the possible state transitions above. + if ods != nil { + Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) + } else if encoded != nil { + Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) } else { - // If an object hasn't had interest registered - // previously, we deferred decoding until interest is - // registered. - ds.deferred[id] = o + Failf("general decoding error: %w", err) } - - ds.stats.Done() - } - - // Check the zero-length header at the end. - length, object, err := ReadHeader(ds.r) - if err != nil { - panic(err) - } - if length != 0 { - panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) - } - if object { - panic("expected non-object terminal") } // Check if we have any deferred objects. - if count := len(ds.deferred); count > 0 { - // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("still have %d deferred objects", count)) - } - - // Scan and fire all callbacks. - for _, os := range ds.objectsByID { - os.checkComplete(ds.stats) + for id, encoded := range ds.deferred { + // Should never happen; the graph was bogus. + Failf("still have deferred objects: one is ID %d, %#v", id, encoded) } - // Check if we have any remaining dependency cycles. - for _, os := range ds.objectsByID { - if !os.complete() { - // This must be the result of a dependency cycle. - cycle := os.findCycle() - var buf bytes.Buffer - buf.WriteString("dependency cycle: {") - for i, cycleOS := range cycle { - if i > 0 { - buf.WriteString(" => ") + // Scan and fire all callbacks. We iterate over the list of incomplete + // objects until all have been finished. We stop iterating if no + // objects become complete (there is a dependency cycle). + // + // Note that we iterate backwards here, because there will be a strong + // tendency for blocking relationships to go from earlier objects to + // later (deeper) objects in the graph. This will reduce the number of + // iterations required to finish all objects. + if err := safely(func() { + for ds.pending.Back() != nil { + thisCycle := false + for ods = ds.pending.Back(); ods != nil; { + if ds.checkComplete(ods) { + thisCycle = true + break } - buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) + ods = ods.Prev() + } + if !thisCycle { + break } - buf.WriteString("}") - // Panic as an error; propagate to the caller. - panic(errors.New(string(buf.Bytes()))) } - } -} - -type byteReader struct { - io.Reader -} - -// ReadByte implements io.ByteReader.
-func (br byteReader) ReadByte() (byte, error) { - var b [1]byte - n, err := br.Reader.Read(b[:]) - if n > 0 { - return b[0], nil - } else if err != nil { - return 0, err - } else { - return 0, io.ErrUnexpectedEOF + }); err != nil { + Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err) + } + + // Check if we have any remaining dependency cycles. If there are any + // objects left in the pending list, then it must be due to a cycle. + if ods := ds.pending.Front(); ods != nil { + // This must be the result of a dependency cycle. + cycle := ods.findCycle() + var buf bytes.Buffer + buf.WriteString("dependency cycle: {") + for i, cycleOS := range cycle { + if i > 0 { + buf.WriteString(" => ") + } + fmt.Fprintf(&buf, "%q", cycleOS.obj.Type()) + } + buf.WriteString("}") + Failf("incomplete graph: %s", string(buf.Bytes())) } } @@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) { // Each object written to the statefile is prefixed with a header. See // WriteHeader for more information; these functions are exported to allow // non-state writes to the file to play nice with debugging tools. -func ReadHeader(r io.Reader) (length uint64, object bool, err error) { +func ReadHeader(r wire.Reader) (length uint64, object bool, err error) { // Read the header. - length, err = binary.ReadUvarint(byteReader{r}) + err = safely(func() { + length = wire.LoadUint(r) + }) if err != nil { - return + // On the header, pass raw I/O errors. + if sErr, ok := err.(*ErrState); ok { + return 0, false, sErr.Unwrap() + } } // Decode whether the object is valid. - object = length&0x1 != 0 - length = length >> 1 + object = length&objectFlag != 0 + length &^= objectFlag return } - -// readObject reads an object from the stream. -func (ds *decodeState) readObject() (*pb.Object, error) { - // Read the header. - length, object, err := ReadHeader(ds.r) - if err != nil { - return nil, err - } - if !object { - return nil, fmt.Errorf("invalid object header") - } - - // Read the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := ds.r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return nil, err - } - } - - // Unmarshal. - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return nil, err - } - - return obj, nil -} diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go new file mode 100644 index 000000000..d048f61a1 --- /dev/null +++ b/pkg/state/decode_unsafe.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "unsafe" +) + +// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on +// values representing unexported fields. This bypasses visibility, but not +// type safety. 
+func unsafePointerTo(obj reflect.Value) reflect.Value { + return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index c5118d3a9..92fcad4e9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -15,437 +15,797 @@ package state import ( - "container/list" "context" - "encoding/binary" - "fmt" - "io" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// queuedObject is an object queued for encoding. -type queuedObject struct { - id uint64 - obj reflect.Value - path recoverable +// objectEncodeState describes the type and identity of an object occupying a memory +// address range. This is the value type for addrSet, and the intrusive entry +// for the pending and deferred lists. +type objectEncodeState struct { + // id is the assigned ID for this object. + id objectID + + // obj is the object value. Note that this may be replaced if we + // encounter an object that contains this object. When this happens (in + // resolve), we will update existing references appropriately, below, + // and defer a re-encoding of the object. + obj reflect.Value + + // encoded is the encoded value of this object. Note that this may not + // be up to date if this object is still in the deferred list. + encoded wire.Object + + // how indicates whether this object should be encoded as a value. This + // is used only for deferred encoding. + how encodeStrategy + + // refs is the list of reference objects used by other objects + // referring to this object. When the object is updated, these + // references may be updated directly and automatically. + refs []*wire.Ref + + pendingEntry + deferredEntry } // encodeState is state used for encoding. // -// The encoding process is a breadth-first traversal of the object graph. The -// inherent races and dependencies are much simpler than the decode case. +// The encoding process constructs a representation of the in-memory graph of +// objects before a single object is serialized. This is done to ensure that +// all references can be fully disambiguated. See resolve for more details. type encodeState struct { // ctx is the encode context. ctx context.Context - // lastID is the last object ID. - // - // See idsByObject for context. Because of the special zero encoding - // used for reference values, the first ID must be 1. - lastID uint64 + // w is the output stream. + w wire.Writer - // idsByObject is a set of objects, indexed via: - // - // reflect.ValueOf(x).UnsafeAddr - // - // This provides IDs for objects. - idsByObject map[uintptr]uint64 + // types is the type database. + types typeEncodeDatabase + + // lastID is the last allocated object ID. + lastID objectID - // values stores values that span the addresses. + // values tracks the address ranges occupied by objects, along with the + // types of these objects. This is used to locate pointer targets, + // including pointers to fields within another type. // - // addrSet is a a generated type which efficiently stores ranges of - // addresses. When encoding pointers, these ranges are filled in and - // used to check for overlapping or conflicting pointers. This would - // indicate a pointer to an field, or a non-type safe value, neither of - // which are currently decodable.
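For intuition, the reflect.NewAt trick that unsafePointerTo relies on can be exercised standalone; a minimal sketch (the secret type and hidden field are hypothetical):

    package main

    import (
        "fmt"
        "reflect"
        "unsafe"
    )

    type secret struct {
        hidden int // unexported: v.Interface() on this field would panic
    }

    func main() {
        s := secret{hidden: 42}
        v := reflect.ValueOf(&s).Elem().Field(0) // addressable, but read-only
        // reflect.NewAt constructs a *int aimed at the field's address,
        // bypassing visibility while keeping the field's static type.
        p := reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr()))
        fmt.Println(p.Elem().Int()) // 42
    }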
+ // Multiple objects may overlap in memory iff the larger object fully + // contains the smaller one, and the type of the smaller object matches + // a field or array element's type at the appropriate offset. An + // arbitrary number of objects may be nested in this manner. // - // See the usage of values below for more context. values addrSet - // w is the output stream. - w io.Writer + // Note that this does not track zero-sized objects; those are tracked + // by zeroValues below. values addrSet - // pending is the list of objects to be serialized. - // - // This is a set of queuedObjects. - pending list.List + // zeroValues tracks zero-sized objects. + zeroValues map[reflect.Type]*objectEncodeState - // done is the a list of finished objects. - // - // This is kept to prevent garbage collection and address reuse. - done list.List + // deferred is the list of objects to be encoded. + deferred deferredList - // stats is the passed stats object. - stats *Stats + // pendingTypes is the list of types to be serialized. Serialization + // will occur when all objects have been encoded, but before pending is + // serialized. + pendingTypes []wire.Type - // recoverable is the panic recover facility. - recoverable + // pending is the list of objects to be serialized. Serialization does + // not actually occur until the full object graph is computed. + pending pendingList + + // stats tracks time data. + stats Stats } -// register looks up an ID, registering if necessary. +// isSameSizeParent returns true if child is a field value or element within +// parent. Only a struct or array can have a child value. +// +// isSameSizeParent deals with objects like this: +// +// struct child { +// // fields.. +// } // -// If the object was not previously registered, it is enqueued to be serialized. -// See the documentation for idsByObject for more information. -func (es *encodeState) register(obj reflect.Value) uint64 { - // It is not legal to call register for any non-pointer objects (see - // below), so we panic with a recoverable error if this is a mismatch. - if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map { - panic(fmt.Errorf("non-pointer %#v registered", obj.Interface())) +// struct parent { +// c child +// } +// +// var p parent +// record(&p.c) +// +// Here, &p and &p.c occupy the exact same address range. +// +// Or like this: +// +// struct child { +// // fields +// } +// +// var arr [1]parent +// record(&arr[0]) +// +// Similarly, &arr[0] and &arr[0].c have the exact same address range. +// +// Precondition: parent and child must occupy the same memory. +func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool { + switch parent.Kind() { + case reflect.Struct: + for i := 0; i < parent.NumField(); i++ { + field := parent.Field(i) + if field.Type() == childType { + return true + } + // Recurse through any intermediate types. + if isSameSizeParent(field, childType) { + return true + } + // Does it make sense to keep going if the first field + // doesn't match? Yes, because there might be an + // arbitrary number of zero-sized fields before we get + // a match, and childType itself can be zero-sized. + } + return false + case reflect.Array: + // The only case where an array with more than one element can + // return true is if childType is zero-sized. In such cases, + // it's ambiguous which element contains the match since a + // zero-sized child object fully fits in any of the zero-sized + // elements in an array...
However since all elements are of + // the same type, we only need to check one element. + // + // For non-zero-sized childTypes, parent.Len() must be 1, but a + // combination of the precondition and an implicit comparison + // between the array element size and childType ensures this. + return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType) + default: + return false } +} - addr := obj.Pointer() - if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 { - // For zero-sized objects, we always provide a unique ID. - // That's because the runtime internally multiplexes pointers - // to the same address. We can't be certain what the intent is - // with pointers to zero-sized objects, so we just give them - // all unique identities. - } else if id, ok := es.idsByObject[addr]; ok { - // Already registered. - return id - } - - // Ensure that the first ID given out is one. See note on lastID. The - // ID zero is used to indicate nil values. +// nextID returns the next valid ID. +func (es *encodeState) nextID() objectID { es.lastID++ - id := es.lastID - es.idsByObject[addr] = id - if obj.Kind() == reflect.Ptr { - // Dereference and treat as a pointer. - es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()}) - - // Register this object at all addresses. - typ := obj.Elem().Type() - if size := typ.Size(); size > 0 { - r := addrRange{addr, addr + size} - if !es.values.IsEmptyRange(r) { - old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable) - panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path())) + return objectID(es.lastID) +} + +// dummyAddr points to the dummy zero-sized address. +var dummyAddr = reflect.ValueOf(new(struct{})).Pointer() + +// resolve records the address range occupied by an object. +func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { + addr := obj.Pointer() + + // Is this a map pointer? Just record the single address. It is not + // possible to take any pointers into the map internals. + if obj.Kind() == reflect.Map { + if addr == 0 { + // Just leave the nil reference alone. This is fine, we + // may need to encode as a reference in this way. We + // return nil for our objectEncodeState so that anyone + // depending on this value knows there's nothing there. + return + } + if seg, _ := es.values.Find(addr); seg.Ok() { + // Ensure the map types match. + existing := seg.Value() + if existing.obj.Type() != obj.Type() { + Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj) } - es.values.Add(r, reflect.ValueOf(es.recoverable.copy())) + + // No sense recording refs, maps may not be replaced by + // covering objects, they are maximal. + ref.Root = wire.Uint(existing.id) + return } + + // Record the map. + oes := &objectEncodeState{ + id: es.nextID(), + obj: obj, + how: encodeMapAsValue, + } + es.values.Add(addrRange{addr, addr + 1}, oes) + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + + // See above: no ref recording. + ref.Root = wire.Uint(oes.id) + return + } + + // If not a map, then the object must be a pointer. + if obj.Kind() != reflect.Ptr { + Failf("attempt to record non-map and non-pointer object %#v", obj) + } + + obj = obj.Elem() // Value from here. + + // Is this a zero-sized type? + typ := obj.Type() + size := typ.Size() + if size == 0 { + if addr == dummyAddr { + // Zero-sized objects point to a dummy byte within the + // runtime. 
There's no sense recording this in the + // address map. We add this to the dedicated + // zeroValues. + // + // Note that zero-sized objects must be *true* + // zero-sized objects. They cannot be part of some + // larger object. In that case, they are assigned a + // 1-byte address at the end of the object. + oes, ok := es.zeroValues[typ] + if !ok { + oes = &objectEncodeState{ + id: es.nextID(), + obj: obj, + } + es.zeroValues[typ] = oes + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + } + + // There's also no sense tracking back references. We + // know that this is a true zero-sized object, and not + // part of a larger container, so it will not change. + ref.Root = wire.Uint(oes.id) + return + } + size = 1 // See above. + } + + // Calculate the container. + end := addr + size + r := addrRange{addr, end} + if seg, _ := es.values.Find(addr); seg.Ok() { + existing := seg.Value() + switch { + case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): + // The object is a perfect match. Happy path. Avoid the + // traversal and just return directly. We don't need to + // encode the type information or any dots here. + ref.Root = wire.Uint(existing.id) + existing.refs = append(existing.refs, ref) + return + + case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): + // The previously registered object is larger than + // this, no need to update. But we expect some + // traversal below. + + case seg.Start() == addr && seg.End() == end: + if !isSameSizeParent(obj, existing.obj.Type()) { + break // Needs traversal. + } + fallthrough // Needs update. + + case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): + // Update the object and redo the encoding. + old := existing.obj + existing.obj = obj + es.deferred.Remove(existing) + es.deferred.PushBack(existing) + + // The previously registered object is superseded by + // this new object. We are guaranteed to not have any + // mergeable neighbours in this segment set. + if !raceEnabled { + seg.SetRangeUnchecked(r) + } else { + // Be extra paranoid. This will be statically + // removed at compile time unless this is a race build. + es.values.Remove(seg) + es.values.Add(r, existing) + seg = es.values.LowerBoundSegment(addr) + } + + // Compute the traversal required & update references. + dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) + wt := es.findType(obj.Type()) + for _, ref := range existing.refs { + ref.Dots = append(ref.Dots, dots...) + ref.Type = wt + } + default: + // There is a nonsensical overlap. + Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + } + + // Compute the new reference, record and return it. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) + ref.Type = es.findType(obj.Type()) + existing.refs = append(existing.refs, ref) + return + } + + // The only remaining case is a pointer value that doesn't overlap with + // any registered addresses. Create a new entry for it, and start + // tracking the first reference we just created. + oes := &objectEncodeState{ + id: es.nextID(), + obj: obj, + } + if !raceEnabled { + es.values.AddWithoutMerging(r, oes) } else { - // Push back the map itself; when maps are encoded from the - // top-level, forceMap will be equal to true. - es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()}) + // Merges should never happen.
This just enables extra + // sanity checks because the Merge function below will panic. + es.values.Add(r, oes) + } + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) +} + +// traverse searches for a target object within a root object, where the target +// object is a struct field or array element within root, with potentially +// multiple intervening types. traverse returns the set of field or element +// traversals required to reach the target. +// +// Note that for efficiency, traverse returns the dots in the reverse order. +// That is, the first traversal required will be the last element of the list. +// +// Precondition: The target object must lie completely within the range defined +// by [rootAddr, rootAddr + sizeof(rootType)]. +func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot { + // Recursion base case: the types actually match. + if targetType == rootType && targetAddr == rootAddr { + return nil } - return id + switch rootType.Kind() { + case reflect.Struct: + offset := targetAddr - rootAddr + for i := rootType.NumField(); i > 0; i-- { + field := rootType.Field(i - 1) + // The first field from the end with an offset that is + // smaller than or equal to our address offset is where + // the target is located. Traverse from there. + if field.Offset <= offset { + dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr) + fieldName := wire.FieldName(field.Name) + return append(dots, &fieldName) + } + } + // Should never happen; the target should be reachable. + Failf("no field in root type %v contains target type %v", rootType, targetType) + + case reflect.Array: + // Since arrays have homogeneous types, all elements have the + // same size and we can compute where the target lives. This + // does not matter for the purpose of typing, but matters for + // the purpose of computing the address of the given index. + elemSize := int(rootType.Elem().Size()) + n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down. + if rootType.Len() < n { + Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements", + targetType, targetAddr, rootType, rootAddr, rootType.Len()) + } + dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr) + return append(dots, wire.Index(n)) + + default: + // For any other type, there's no possibility of aliasing so if + // the types didn't match earlier then we have an address + // collision which shouldn't be possible at this point. + Failf("traverse failed for root type %v and target type %v", rootType, targetType) + } + panic("unreachable") } // encodeMap encodes a map. -func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map { - var ( - keys []*pb.Object - values []*pb.Object - ) +func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) { + if obj.IsNil() { + // Because there is a difference between a nil map and an empty + // map, we must take care not to decode a truly nil map as an empty one.
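The offset arithmetic in traverse can be mirrored directly with package reflect; a rough standalone sketch for a pointer into an array field (the pair type is hypothetical):

    package main

    import (
        "fmt"
        "reflect"
        "unsafe"
    )

    type pair struct {
        A uint32
        B [2]uint64
    }

    func main() {
        var p pair
        root := uintptr(unsafe.Pointer(&p))
        target := uintptr(unsafe.Pointer(&p.B[1]))

        // Mirror traverse: locate the field whose offset covers the target,
        // then divide the remaining offset by the array element size.
        field, _ := reflect.TypeOf(p).FieldByName("B")
        idx := (target - root - field.Offset) / field.Type.Elem().Size()
        fmt.Printf("dots (reversed): [Index(%d), FieldName(B)]\n", idx) // Index(1)
    }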
+ *dest = wire.Nil{} + return + } + l := obj.Len() + m := &wire.Map{ + Keys: make([]wire.Object, l), + Values: make([]wire.Object, l), + } + *dest = m for i, k := range obj.MapKeys() { v := obj.MapIndex(k) - kp := es.encodeObject(k, false, ".(key %d)", i) - vp := es.encodeObject(v, false, "[%#v]", k.Interface()) - keys = append(keys, kp) - values = append(values, vp) + // Map keys must be encoded using the full value because the + // type will be omitted after the first key. + es.encodeObject(k, encodeAsValue, &m.Keys[i]) + es.encodeObject(v, encodeAsValue, &m.Values[i]) } - return &pb.Map{Keys: keys, Values: values} +} + +// objectEncoder is for encoding structs. +type objectEncoder struct { + // es is encodeState. + es *encodeState + + // encoded is the encoded struct. + encoded *wire.Struct +} + +// save is called by the public methods on Sink. +func (oe *objectEncoder) save(slot int, obj reflect.Value) { + fieldValue := oe.encoded.Field(slot) + oe.es.encodeObject(obj, encodeDefault, fieldValue) } // encodeStruct encodes a composite object. -func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct { - // Invoke the save. - m := Map{newInternalMap(es, nil, nil)} - defer internalMapPool.Put(m.internalMap) +func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + // Ensure that the obj is addressable. There are two cases when it is + // not. First is when this is dispatched via SaveValue. Second is when + // this is a struct used as a map key. Either way, we need to make a copy to + // obtain an addressable value. if !obj.CanAddr() { - // Force it to a * type of the above; this involves a copy. localObj := reflect.New(obj.Type()) localObj.Elem().Set(obj) obj = localObj.Elem() } - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the provided saver. - fns.invokeSave(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow unregistered anonymous, empty structs. - return &pb.Struct{} - } else { - // Propagate an error. - panic(fmt.Errorf("unregistered type %T", obj.Interface())) - } - - // Sort the underlying slice, and check for duplicates. This is done - // once instead of on each add, because performing this sort once is - // far more efficient. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - for i := range m.data { - if i > 0 && m.data[i-1].name == m.data[i].name { - panic(fmt.Errorf("duplicate name %s", m.data[i].name)) - } + + // Prepare the value. + s := &wire.Struct{} + *dest = s + + // Look the type up in the database. + te, ok := es.types.Lookup(obj.Type()) + if te == nil { + if obj.NumField() == 0 { + // Allow unregistered anonymous, empty structs. This + // will just return success without ever invoking the + // passed function. This uses the immutable EmptyStruct + // variable to prevent an allocation in this case. + // + // Note that this mechanism does *not* work for + // interfaces in general. So you can't dispatch + // non-registered empty structs via interfaces because + // then they can't be restored. + s.Alloc(0) + return } + // We need a SaverLoader for struct types. + Failf("struct %T does not implement SaverLoader", obj.Interface()) } - - // Encode the resulting fields. - fields := make([]*pb.Field, 0, len(m.data)) - for _, e := range m.data { - fields = append(fields, &pb.Field{ - Name: e.name, - Value: e.object, - }) + if !ok { + // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type) } - // Return the encoded object. - return &pb.Struct{Fields: fields} + // Invoke the provided saver. + s.TypeID = wire.TypeID(te.ID) + s.Alloc(len(te.Fields)) + oe := objectEncoder{ + es: es, + encoded: s, + } + es.stats.start(te.ID) + defer es.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateSave(Sink{internal: oe}) + } } // encodeArray encodes an array. -func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array { - var ( - contents []*pb.Object - ) - for i := 0; i < obj.Len(); i++ { - entry := es.encodeObject(obj.Index(i), false, "[%d]", i) - contents = append(contents, entry) - } - return &pb.Array{Contents: contents} +func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) { + l := obj.Len() + a := &wire.Array{ + Contents: make([]wire.Object, l), + } + *dest = a + for i := 0; i < l; i++ { + // We need to encode the full value because arrays are encoded + // using the type information from only the first element. + es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i]) + } +} + +// findType recursively finds type information. +func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec { + // First: check if this is a proper type. It's possible for pointers, + // slices, arrays, maps, etc. to all have some different type. + te, ok := es.types.Lookup(typ) + if te != nil { + if !ok { + // See encodeStruct. + es.pendingTypes = append(es.pendingTypes, te.Type) + } + return wire.TypeID(te.ID) + } + + switch typ.Kind() { + case reflect.Ptr: + return &wire.TypeSpecPointer{ + Type: es.findType(typ.Elem()), + } + case reflect.Slice: + return &wire.TypeSpecSlice{ + Type: es.findType(typ.Elem()), + } + case reflect.Array: + return &wire.TypeSpecArray{ + Count: wire.Uint(typ.Len()), + Type: es.findType(typ.Elem()), + } + case reflect.Map: + return &wire.TypeSpecMap{ + Key: es.findType(typ.Key()), + Value: es.findType(typ.Elem()), + } + default: + // After potentially chasing many pointers, the + // ultimate type of the object is not known. + Failf("type %q is not known", typ) + } + panic("unreachable") } // encodeInterface encodes an interface. -// -// Precondition: the value is not nil. -func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface { - // Check for the nil interface. - obj = reflect.ValueOf(obj.Interface()) +func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) { + // Dereference the object. + obj = obj.Elem() if !obj.IsValid() { - return &pb.Interface{ - Type: "", // left alone in decode. - Value: &pb.Object{Value: &pb.Object_RefValue{0}}, + // Special case: the nil object. + *dest = &wire.Interface{ + Type: wire.TypeSpecNil{}, + Value: wire.Nil{}, } + return } - // We have an interface value here. How do we save that? We - // resolve the underlying type and save it as a dispatchable. - typName, ok := registeredTypes.lookupName(obj.Type()) - if !ok { - panic(fmt.Errorf("type %s is not registered", obj.Type())) + + // Encode underlying object. + i := &wire.Interface{ + Type: es.findType(obj.Type()), } + *dest = i + es.encodeObject(obj, encodeAsValue, &i.Value) +} - // Encode the object again. - return &pb.Interface{ - Type: typName, - Value: es.encodeObject(obj, false, ".(%s)", typName), +// isPrimitiveZero returns true if this is a primitive object, or a composite +// object composed entirely of primitives.
+func isPrimitiveZero(typ reflect.Type) bool { + switch typ.Kind() { + case reflect.Ptr: + // Pointers are always treated as primitive types because we + // won't encode directly from here. Returning true here won't + // prevent the object from being encoded correctly. + return true + case reflect.Bool: + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Complex64, reflect.Complex128: + return true + case reflect.String: + return true + case reflect.Slice: + // The slice itself is a primitive, but not necessarily the array + // it points to. This is similar to a pointer. + return true + case reflect.Array: + // We cannot treat an array as a primitive, because it may be + // composed of structures or other things with side-effects. + return isPrimitiveZero(typ.Elem()) + case reflect.Interface: + // Since we know that this type is the zero type, the interface + // value must be zero. Therefore this is primitive. + return true + case reflect.Struct: + return false + case reflect.Map: + // The isPrimitiveZero function is called only on zero-types to + // see if it's safe to serialize. Since a zero map has no + // elements, it is safe to treat as a primitive. + return true + default: + Failf("unknown type %q", typ.Name()) } + panic("unreachable") } -// encodeObject encodes an object. -// -// If mapAsValue is true, then a map will be encoded directly. -func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) { - es.push(false, format, param) - es.stats.Add(obj) - es.stats.Start(obj) +// encodeStrategy is the strategy used for encodeObject. +type encodeStrategy int +const ( + // encodeDefault means types are encoded normally as references. + encodeDefault encodeStrategy = iota + + // encodeAsValue means that types will never be short-circuited and + // will always be encoded as a normal value. + encodeAsValue + + // encodeMapAsValue means that even maps will be fully encoded. + encodeMapAsValue +) + +// encodeObject encodes an object. +func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) { + if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() { + *dest = wire.Nil{} + return + } switch obj.Kind() { + case reflect.Ptr: // Fast path: first. + r := new(wire.Ref) + *dest = r + if obj.IsNil() { + // May be in an array or elsewhere such that a value is + // required. So we encode as a reference to the zero + // object, which does not exist. Note that this has to + // be handled correctly in the decode path as well.
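The zero-value fast path above leans on reflect's IsZero; a quick sketch of which values would take the wire.Nil shortcut (sample values chosen arbitrarily):

    package main

    import (
        "fmt"
        "reflect"
    )

    func main() {
        // Values reporting IsZero (whose types also pass isPrimitiveZero)
        // can be elided: the decoder leaves the zero allocation untouched.
        for _, v := range []interface{}{0, "", 0.5, [2]int{}, map[int]int(nil)} {
            rv := reflect.ValueOf(v)
            fmt.Printf("%T %v -> IsZero=%v\n", v, v, rv.IsZero())
        }
    }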
+ return + } + es.resolve(obj, r) case reflect.Bool: - object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}} + *dest = wire.Bool(obj.Bool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}} + *dest = wire.Int(obj.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}} - case reflect.Float32, reflect.Float64: - object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}} + *dest = wire.Uint(obj.Uint()) + case reflect.Float32: + *dest = wire.Float32(obj.Float()) + case reflect.Float64: + *dest = wire.Float64(obj.Float()) + case reflect.Complex64: + c := wire.Complex64(obj.Complex()) + *dest = &c // Needs alloc. + case reflect.Complex128: + c := wire.Complex128(obj.Complex()) + *dest = &c // Needs alloc. + case reflect.String: + s := wire.String(obj.String()) + *dest = &s // Needs alloc. case reflect.Array: - switch obj.Type().Elem().Kind() { - case reflect.Uint8: - object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}} - case reflect.Uint16: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := pbSlice(obj).Interface().([]uint16) - t := make([]uint32, len(s)) - for i := range s { - t[i] = uint32(s[i]) - } - object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}} - case reflect.Uint32: - object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}} - case reflect.Uint64: - object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}} - case reflect.Uintptr: - object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}} - case reflect.Int8: - object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}} - case reflect.Int16: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := pbSlice(obj).Interface().([]int16) - t := make([]int32, len(s)) - for i := range s { - t[i] = int32(s[i]) - } - object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}} - case reflect.Int32: - object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}} - case reflect.Int64: - object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}} - case reflect.Bool: - object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}} - case reflect.Float32: - object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}} - case reflect.Float64: - object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}} - default: - object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}} - } + es.encodeArray(obj, dest) case reflect.Slice: - if obj.IsNil() || obj.Cap() == 0 { - // Handled specially in decode; store as nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else { - // Serialize a slice as the array plus length and capacity. 
- object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{ - Capacity: uint32(obj.Cap()), - Length: uint32(obj.Len()), - RefValue: es.register(arrayFromSlice(obj)), - }}} + s := &wire.Slice{ + Capacity: wire.Uint(obj.Cap()), + Length: wire.Uint(obj.Len()), } - case reflect.String: - object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}} - case reflect.Ptr: + *dest = s + // Note that we do need to provide a wire.Slice type here as + // how is not encodeDefault. If this were the case, then it + // would have been caught by the IsZero check above and we + // would have just used wire.Nil{}. if obj.IsNil() { - // Handled specially in decode; store as a nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else { - es.push(true /* dereference */, "", nil) - object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} - es.pop() + return } + // Slices need pointer resolution. + es.resolve(arrayFromSlice(obj), &s.Ref) case reflect.Interface: - // We don't check for IsNil here, as we want to encode type - // information. The case of the empty interface (no type, no - // value) is handled by encodeInteface. - object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}} + es.encodeInterface(obj, dest) case reflect.Struct: - object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}} + es.encodeStruct(obj, dest) case reflect.Map: - if obj.IsNil() { - // Handled specially in decode; store as a nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else if mapAsValue { - // Encode the map directly. - object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}} - } else { - // Encode a reference to the map. - // - // Remove the map object count here to avoid double - // counting, as this object will be counted again when - // it gets processed later. We do not add a reference - // count as the reference is artificial. - es.stats.Remove(obj) - object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} + if how == encodeMapAsValue { + es.encodeMap(obj, dest) + return } + r := new(wire.Ref) + *dest = r + es.resolve(obj, r) default: - panic(fmt.Errorf("unknown primitive %#v", obj.Interface())) + Failf("unknown object %#v", obj.Interface()) + panic("unreachable") } - - es.stats.Done() - es.pop() - return } -// Serialize serializes the object state. -// -// This function may panic and should be run in safely(). -func (es *encodeState) Serialize(obj reflect.Value) { - es.register(obj.Addr()) - - // Pop off the list until we're done. - for es.pending.Len() > 0 { - e := es.pending.Front() - - // Extract the queued object. - qo := e.Value.(queuedObject) - es.stats.Start(qo.obj) +// Save serializes the object graph rooted at obj. +func (es *encodeState) Save(obj reflect.Value) { + es.stats.init() + defer es.stats.fini(func(id typeID) string { + return es.pendingTypes[id-1].Name + }) + + // Resolve the first object, which should queue a pile of additional + // objects on the pending list. All queued objects should be fully + // resolved, and we should be able to serialize after this call. + var root wire.Ref + es.resolve(obj.Addr(), &root) + + // Encode the graph. + var oes *objectEncodeState + if err := safely(func() { + for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() { + // Remove and encode the object. Note that as a result + // of this encoding, the object may be enqueued on the + // deferred list yet again. That's expected, and why it + // is removed first. 
+ es.deferred.Remove(oes) + es.encodeObject(oes.obj, oes.how, &oes.encoded) + } + }); err != nil { + // Include the object in the error message. + Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) + } - es.pending.Remove(e) + // Check that items are pending. + if es.pending.Front() == nil { + Failf("pending is empty?") + } - es.from = &qo.path - o := es.encodeObject(qo.obj, true, "", nil) + // Write the header with the number of objects. Note that there is no + // way that es.lastID could conflict with objectFlag, which would + // indicate an impossibly large encoding. + if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + Failf("error writing header: %w", err) + } - // Emit to our output stream. - if err := es.writeObject(qo.id, o); err != nil { - panic(err) + // Serialize all pending types and pending objects. Note that we don't + // bother removing from this list as we walk it because that just + // wastes time. It will not change after this point. + var id objectID + if err := safely(func() { + for _, wt := range es.pendingTypes { + // Encode the type. + wire.Save(es.w, &wt) } + for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { + id++ // First object is 1. + if oes.id != id { + Failf("expected id %d, got %d", id, oes.id) + } - // Mark as done. - es.done.PushBack(e) - es.stats.Done() + // Marshal the object. + wire.Save(es.w, oes.encoded) + } + }); err != nil { + // Include the object and the error. + Failf("error serializing object %#v: %w", oes.encoded, err) } - // Write a zero-length terminal at the end; this is a sanity check - // applied at decode time as well (see decode.go). - if err := WriteHeader(es.w, 0, false); err != nil { - panic(err) + // Check what we wrote. + if id != es.lastID { + Failf("expected %d objects, wrote %d", es.lastID, id) } } +// objectFlag indicates that the length is a # of objects, rather than a raw +// byte length. When this is set on a length header in the stream, it may be +// decoded appropriately. +const objectFlag uint64 = 1 << 63 + // WriteHeader writes a header. // // Each object written to the statefile should be prefixed with a header. In // order to generate statefiles that play nicely with debugging tools, raw // writes should be prefixed with a header with object set to false and the // appropriate length. This will allow tools to skip these regions. -func WriteHeader(w io.Writer, length uint64, object bool) error { - // The lowest-order bit encodes whether this is a valid object. This is - // a purely internal convention, but allows the object flag to be - // returned from ReadHeader. - length = length << 1 +func WriteHeader(w wire.Writer, length uint64, object bool) error { - // Sanity check the length. + // Sanity check the length. + if length&objectFlag != 0 { + Failf("impossibly huge length: %d", length) + } if object { - length |= 0x1 + length |= objectFlag } // Write a header. - var hdr [32]byte - encodedLen := binary.PutUvarint(hdr[:], length) - for done := 0; done < encodedLen; { - n, err := w.Write(hdr[done:encodedLen]) - done += n - if n == 0 && err != nil { - return err - } - } - - return nil + return safely(func() { + wire.SaveUint(w, length) + }) } -// writeObject writes an object to the stream. -func (es *encodeState) writeObject(id uint64, obj *pb.Object) error { - // Marshal the proto. - buf, err := proto.Marshal(obj) - if err != nil { - return err - } + +// pendingMapper is for the pending list. +type pendingMapper struct{} - // Write the object header.
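A small sketch of the objectFlag packing shared by WriteHeader and ReadHeader (standalone, with the constant copied from above):

    package main

    import "fmt"

    // objectFlag mirrors the constant above: the top bit marks an
    // object-count header rather than a raw byte length.
    const objectFlag uint64 = 1 << 63

    func main() {
        length := uint64(42)
        hdr := length | objectFlag // pack, as in WriteHeader

        object := hdr&objectFlag != 0 // unpack, as in ReadHeader
        hdr &^= objectFlag
        fmt.Println(object, hdr) // true 42
    }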
- if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil { - return err - } +func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // Write the object. - for done := 0; done < len(buf); { - n, err := es.w.Write(buf[done:]) - done += n - if n == 0 && err != nil { - return err - } - } +// deferredMapper is for the deferred list. +type deferredMapper struct{} - return nil -} +func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry } // addrSetFunctions is used by addrSet. type addrSetFunctions struct{} @@ -458,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr { return ^uintptr(0) } -func (addrSetFunctions) ClearValue(val *reflect.Value) { +func (addrSetFunctions) ClearValue(val **objectEncodeState) { + *val = nil } -func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) { - return val1, val1 == val2 +func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) { + if val1.obj == val2.obj { + // This should never happen. It would indicate that the same + // object exists in two non-contiguous address ranges. Note + // that this assertion can only be triggered if the race + // detector is enabled. + Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj) + } + // Reject the merge. + return val1, false } -func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) { - return val, val +func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) { + // A split should never happen: we don't remove ranges. + Failf("unexpected split in addrSet @ %v: %#v", r, val.obj) + panic("unreachable") } diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go index 457e6dbb7..e0dad83b4 100644 --- a/pkg/state/encode_unsafe.go +++ b/pkg/state/encode_unsafe.go @@ -31,51 +31,3 @@ func arrayFromSlice(obj reflect.Value) reflect.Value { reflect.ArrayOf(obj.Cap(), obj.Type().Elem()), unsafe.Pointer(obj.Pointer())) } - -// pbSlice returns a protobuf-supported slice of the array and erase the -// original element type (which could be a defined type or non-supported type).
-func pbSlice(obj reflect.Value) reflect.Value { - var typ reflect.Type - switch obj.Type().Elem().Kind() { - case reflect.Uint8: - typ = reflect.TypeOf(byte(0)) - case reflect.Uint16: - typ = reflect.TypeOf(uint16(0)) - case reflect.Uint32: - typ = reflect.TypeOf(uint32(0)) - case reflect.Uint64: - typ = reflect.TypeOf(uint64(0)) - case reflect.Uintptr: - typ = reflect.TypeOf(uint64(0)) - case reflect.Int8: - typ = reflect.TypeOf(byte(0)) - case reflect.Int16: - typ = reflect.TypeOf(int16(0)) - case reflect.Int32: - typ = reflect.TypeOf(int32(0)) - case reflect.Int64: - typ = reflect.TypeOf(int64(0)) - case reflect.Bool: - typ = reflect.TypeOf(bool(false)) - case reflect.Float32: - typ = reflect.TypeOf(float32(0)) - case reflect.Float64: - typ = reflect.TypeOf(float64(0)) - default: - panic("slice element is not of basic value type") - } - return reflect.NewAt( - reflect.ArrayOf(obj.Len(), typ), - unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), - ).Elem().Slice(0, obj.Len()) -} - -func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value { - if obj.Type().Elem().Size() != elemTyp.Size() { - panic("cannot cast slice into other element type of different size") - } - return reflect.NewAt( - reflect.ArrayOf(obj.Len(), elemTyp), - unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), - ).Elem() -} diff --git a/pkg/state/map.go b/pkg/state/map.go deleted file mode 100644 index 4f3ebb0da..000000000 --- a/pkg/state/map.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "context" - "fmt" - "reflect" - "sort" - "sync" - - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" -) - -// entry is a single map entry. -type entry struct { - name string - object *pb.Object -} - -// internalMap is the internal Map state. -// -// These are recycled via a pool to avoid churn. -type internalMap struct { - // es is encodeState. - es *encodeState - - // ds is decodeState. - ds *decodeState - - // os is current object being decoded. - // - // This will always be nil during encode. - os *objectState - - // data stores the encoded values. - data []entry -} - -var internalMapPool = sync.Pool{ - New: func() interface{} { - return new(internalMap) - }, -} - -// newInternalMap returns a cached map. -func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap { - m := internalMapPool.Get().(*internalMap) - m.es = es - m.ds = ds - m.os = os - if m.data != nil { - m.data = m.data[:0] - } - return m -} - -// Map is a generic state container. -// -// This is the object passed to Save and Load in order to store their state. -// -// Detailed documentation is available in individual methods. -type Map struct { - *internalMap -} - -// Save adds the given object to the map. -// -// You should pass always pointers to the object you are saving. 
For example: -// -// type X struct { -// A int -// B *int -// } -// -// func (x *X) Save(m Map) { -// m.Save("A", &x.A) -// m.Save("B", &x.B) -// } -// -// func (x *X) Load(m Map) { -// m.Load("A", &x.A) -// m.Load("B", &x.B) -// } -func (m Map) Save(name string, objPtr interface{}) { - m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s") -} - -// SaveValue adds the given object value to the map. -// -// This should be used for values where pointers are not available, or casts -// are required during Save/Load. -// -// For example, if we want to cast external package type P.Foo to int64: -// -// type X struct { -// A P.Foo -// } -// -// func (x *X) Save(m Map) { -// m.SaveValue("A", int64(x.A)) -// } -// -// func (x *X) Load(m Map) { -// m.LoadValue("A", new(int64), func(x interface{}) { -// x.A = P.Foo(x.(int64)) -// }) -// } -func (m Map) SaveValue(name string, obj interface{}) { - m.save(name, reflect.ValueOf(obj), ".(value %s)") -} - -// save is helper for the above. It takes the name of value to save the field -// to, the field object (obj), and a format string that specifies how the -// field's saving logic is dispatched from the struct (normal, value, etc.). The -// format string should expect one string parameter, which is the name of the -// field. -func (m Map) save(name string, obj reflect.Value, format string) { - if m.es == nil { - // Not currently encoding. - m.Failf("no encode state for %q", name) - } - - // Attempt the encode. - // - // These are sorted at the end, after all objects are added and will be - // sorted and checked for duplicates (see encodeStruct). - m.data = append(m.data, entry{ - name: name, - object: m.es.encodeObject(obj, false, format, name), - }) -} - -// Load loads the given object from the map. -// -// See Save for an example. -func (m Map) Load(name string, objPtr interface{}) { - m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s") -} - -// LoadWait loads the given objects from the map, and marks it as requiring all -// AfterLoad executions to complete prior to running this object's AfterLoad. -// -// See Save for an example. -func (m Map) LoadWait(name string, objPtr interface{}) { - m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)") -} - -// LoadValue loads the given object value from the map. -// -// See SaveValue for an example. -func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) { - o := reflect.ValueOf(objPtr) - m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)") -} - -// load is helper for the above. It takes the name of value to load the field -// from, the target field pointer (objPtr), whether load completion of the -// struct depends on the field's load completion (wait), the load completion -// logic (fn), and a format string that specifies how the field's loading logic -// is dispatched from the struct (normal, wait, value, etc.). The format string -// should expect one string parameter, which is the name of the field. -func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) { - if m.ds == nil { - // Not currently decoding. - m.Failf("no decode state for %q", name) - } - - // Find the object. - // - // These are sorted up front (and should appear in the state file - // sorted as well), so we can do a binary search here to ensure that - // large structs don't behave badly. 
- i := sort.Search(len(m.data), func(i int) bool { - return m.data[i].name >= name - }) - if i >= len(m.data) || m.data[i].name != name { - // There is no data for this name? - m.Failf("no data found for %q", name) - } - - // Perform the decode. - m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name) - if wait { - // Mark this individual object a blocker. - m.ds.waitObject(m.os, m.data[i].object, fn) - } -} - -// Failf fails the save or restore with the provided message. Processing will -// stop after calling Failf, as the state package uses a panic & recover -// mechanism for state errors. You should defer any cleanup required. -func (m Map) Failf(format string, args ...interface{}) { - panic(fmt.Errorf(format, args...)) -} - -// AfterLoad schedules a function execution when all objects have been allocated -// and their automated loading and customized load logic have been executed. fn -// will not be executed until all of current object's dependencies' AfterLoad() -// logic, if exist, have been executed. -func (m Map) AfterLoad(fn func()) { - if m.ds == nil { - // Not currently decoding. - m.Failf("not decoding") - } - - // Queue the local callback; this will execute when all of the above - // data dependencies have been cleared. - m.os.callbacks = append(m.os.callbacks, fn) -} - -// Context returns the current context object. -func (m Map) Context() context.Context { - if m.es != nil { - return m.es.ctx - } else if m.ds != nil { - return m.ds.ctx - } - return context.Background() // No context. -} diff --git a/pkg/state/object.proto b/pkg/state/object.proto deleted file mode 100644 index 5ebcfb151..000000000 --- a/pkg/state/object.proto +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor.state.statefile; - -// Slice is a slice value. -message Slice { - uint32 length = 1; - uint32 capacity = 2; - uint64 ref_value = 3; -} - -// Array is an array value. -message Array { - repeated Object contents = 1; -} - -// Map is a map value. -message Map { - repeated Object keys = 1; - repeated Object values = 2; -} - -// Interface is an interface value. -message Interface { - string type = 1; - Object value = 2; -} - -// Struct is a basic composite value. -message Struct { - repeated Field fields = 1; -} - -// Field encodes a single field. -message Field { - string name = 1; - Object value = 2; -} - -// Uint16s encodes an uint16 array. To be used inside oneof structure. -message Uint16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated uint32 values = 1; -} - -// Uint32s encodes an uint32 array. To be used inside oneof structure. -message Uint32s { - repeated fixed32 values = 1; -} - -// Uint64s encodes an uint64 array. To be used inside oneof structure. -message Uint64s { - repeated fixed64 values = 1; -} - -// Uintptrs encodes an uintptr array. To be used inside oneof structure. 
-message Uintptrs { - repeated fixed64 values = 1; -} - -// Int8s encodes an int8 array. To be used inside oneof structure. -message Int8s { - bytes values = 1; -} - -// Int16s encodes an int16 array. To be used inside oneof structure. -message Int16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated int32 values = 1; -} - -// Int32s encodes an int32 array. To be used inside oneof structure. -message Int32s { - repeated sfixed32 values = 1; -} - -// Int64s encodes an int64 array. To be used inside oneof structure. -message Int64s { - repeated sfixed64 values = 1; -} - -// Bools encodes a boolean array. To be used inside oneof structure. -message Bools { - repeated bool values = 1; -} - -// Float64s encodes a float64 array. To be used inside oneof structure. -message Float64s { - repeated double values = 1; -} - -// Float32s encodes a float32 array. To be used inside oneof structure. -message Float32s { - repeated float values = 1; -} - -// Object are primitive encodings. -// -// Note that ref_value references an Object.id, below. -message Object { - oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; - } -} diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD new file mode 100644 index 000000000..d053802f7 --- /dev/null +++ b/pkg/state/pretty/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pretty", + srcs = ["pretty.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/state", + "//pkg/state/wire", + ], +) diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go new file mode 100644 index 000000000..cf37aaa49 --- /dev/null +++ b/pkg/state/pretty/pretty.go @@ -0,0 +1,273 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pretty is a pretty-printer for state streams. +package pretty + +import ( + "fmt" + "io" + "io/ioutil" + "reflect" + "strings" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" +) + +func formatRef(x *wire.Ref, graph uint64, html bool) string { + baseRef := fmt.Sprintf("g%dr%d", graph, x.Root) + fullRef := baseRef + if len(x.Dots) > 0 { + // See wire.Ref; Type valid if Dots non-zero. 
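(To make the reference syntax concrete: for a Ref in graph 1 with Root = 3, a Type that resolves to type ID 2, and Dots holding the field name "parent" followed by index 0, the function above renders (*g1t2)(g1r3.parent[0]) in text mode; with html set, the whole expression additionally becomes a link to the anchor #g1r3.)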
+ typ, _ := formatType(x.Type, graph, html) + var buf strings.Builder + buf.WriteString("(*") + buf.WriteString(typ) + buf.WriteString(")(") + buf.WriteString(baseRef) + for _, component := range x.Dots { + switch v := component.(type) { + case *wire.FieldName: + buf.WriteString(".") + buf.WriteString(string(*v)) + case wire.Index: + buf.WriteString(fmt.Sprintf("[%d]", v)) + default: + panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) + } + } + buf.WriteString(")") + fullRef = buf.String() + } + if html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef) + } + return fullRef +} + +func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { + switch x := t.(type) { + case wire.TypeID: + base := fmt.Sprintf("g%dt%d", graph, x) + if html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true + } + return fmt.Sprintf("%s", base), true + case wire.TypeSpecNil: + return "", false // Only nil type. + case *wire.TypeSpecPointer: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("(*%s)", element), true + case *wire.TypeSpecArray: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("[%d](%s)", x.Count, element), true + case *wire.TypeSpecSlice: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("([]%s)", element), true + case *wire.TypeSpecMap: + key, _ := formatType(x.Key, graph, html) + value, _ := formatType(x.Value, graph, html) + return fmt.Sprintf("(map[%s]%s)", key, value), true + default: + panic(fmt.Sprintf("unreachable: unknown type %T", t)) + } +} + +// format formats a single object, for pretty-printing. It also returns whether +// the value is a non-zero value. +func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) { + switch x := encoded.(type) { + case wire.Nil: + return "nil", false + case *wire.String: + return fmt.Sprintf("%q", *x), *x != "" + case *wire.Complex64: + return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 + case *wire.Complex128: + return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 + case *wire.Ref: + return formatRef(x, graph, html), x.Root != 0 + case *wire.Type: + tabs := "\n" + strings.Repeat("\t", depth) + items := make([]string, 0, len(x.Fields)+2) + items = append(items, fmt.Sprintf("type %s {", x.Name)) + for i := 0; i < len(x.Fields); i++ { + items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i])) + } + items = append(items, "}") + return strings.Join(items, tabs), true // No zero value. + case *wire.Slice: + return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0 + case *wire.Array: + if len(x.Contents) == 0 { + return "[]", false + } + items := make([]string, 0, len(x.Contents)+2) + zeros := make([]string, 0) // used to eliminate zero entries. + items = append(items, "[") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.Contents); i++ { + item, ok := format(graph, depth+1, x.Contents[i], html) + if !ok { + zeros = append(zeros, fmt.Sprintf("\t%s,", item)) + continue + } + if len(zeros) > 0 { + items = append(items, zeros...) + zeros = nil + } + items = append(items, fmt.Sprintf("\t%s,", item)) + } + if len(zeros) > 0 { + items = append(items, fmt.Sprintf("\t... 
(%d zeros),", len(zeros)))
+		}
+		items = append(items, "]")
+		return strings.Join(items, tabs), len(zeros) < len(x.Contents)
+	case *wire.Struct:
+		typ, _ := formatType(x.TypeID, graph, html)
+		if x.Fields() == 0 {
+			return fmt.Sprintf("struct[%s]{}", typ), false
+		}
+		items := make([]string, 0, 2)
+		items = append(items, fmt.Sprintf("struct[%s]{", typ))
+		tabs := "\n" + strings.Repeat("\t", depth)
+		allZero := true
+		for i := 0; i < x.Fields(); i++ {
+			element, ok := format(graph, depth+1, *x.Field(i), html)
+			allZero = allZero && !ok
+			items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
+		}
+		items = append(items, "}")
+		return strings.Join(items, tabs), !allZero
+	case *wire.Map:
+		if len(x.Keys) == 0 {
+			return "map{}", false
+		}
+		items := make([]string, 0, len(x.Keys)+2)
+		items = append(items, "map{")
+		tabs := "\n" + strings.Repeat("\t", depth)
+		for i := 0; i < len(x.Keys); i++ {
+			key, _ := format(graph, depth+1, x.Keys[i], html)
+			value, _ := format(graph, depth+1, x.Values[i], html)
+			items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+		}
+		items = append(items, "}")
+		return strings.Join(items, tabs), true
+	case *wire.Interface:
+		typ, typOk := formatType(x.Type, graph, html)
+		element, elementOk := format(graph, depth+1, x.Value, html)
+		return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
+	default:
+		// Must be a primitive; use reflection.
+		return fmt.Sprintf("%v", encoded), true
+	}
+}
+
+// printStream is the basic print implementation.
+func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+	// current graph ID.
+	var graph uint64
+
+	if html {
+		fmt.Fprintf(w, "<pre>")
+		defer fmt.Fprintf(w, "</pre>")
+	}
+
+	defer func() {
+		if r := recover(); r != nil {
+			if rErr, ok := r.(error); ok {
+				err = rErr // Override return.
+				return
+			}
+			panic(r) // Propagate.
+		}
+	}()
+
+	for {
+		// Find the first object to begin generation.
+		length, object, err := state.ReadHeader(r)
+		if err == io.EOF {
+			// Nothing else to do.
+			break
+		} else if err != nil {
+			return err
+		}
+		if !object {
+			graph++ // Increment the graph.
+			if length > 0 {
+				fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+				io.Copy(ioutil.Discard, &io.LimitedReader{
+					R: r,
+					N: int64(length),
+				})
+			}
+			continue
+		}
+
+		// Read & unmarshal the object.
+		//
+		// Note that this loop must match the general structure of the
+		// loop in decode.go. But we don't register type information,
+		// etc. and just print the raw structures.
+		var (
+			oid uint64 = 1
+			tid uint64 = 1
+		)
+		for oid <= length {
+			// Unmarshal the object.
+			encoded := wire.Load(r)
+
+			// Is this a type?
+			if _, ok := encoded.(*wire.Type); ok {
+				str, _ := format(graph, 0, encoded, html)
+				tag := fmt.Sprintf("g%dt%d", graph, tid)
+				if html {
+					// See below.
+					tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag)
+				}
+				if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+					return err
+				}
+				tid++
+				continue
+			}
+
+			// Format the node.
+			str, _ := format(graph, 0, encoded, html)
+			tag := fmt.Sprintf("g%dr%d", graph, oid)
+			if html {
+				// Create a little tag with an anchor next to it for linking.
+				tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag)
+			}
+			if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+				return err
+			}
+			oid++
+		}
+	}
+
+	return nil
+}
+
+// PrintText reads the stream from r and prints text to w.
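(A hedged usage sketch, not part of this commit: given a decrypted statefile, the printer can consume its wire.Reader directly, e.g. in a debugging tool:

	r, _, err := statefile.NewReader(f, key)
	if err != nil {
		return err
	}
	if err := pretty.PrintText(os.Stdout, r); err != nil {
		return err
	}

statefile.NewReader returning a wire.Reader is itself part of this change; see the pkg/state/statefile diff further below.)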
+func PrintText(w io.Writer, r wire.Reader) error { + return printStream(w, r, false /* html */) +} + +// PrintHTML reads the stream from r and prints html to w. +func PrintHTML(w io.Writer, r wire.Reader) error { + return printStream(w, r, true /* html */) +} diff --git a/pkg/state/printer.go b/pkg/state/printer.go deleted file mode 100644 index 3ce18242f..000000000 --- a/pkg/state/printer.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "fmt" - "io" - "io/ioutil" - "reflect" - "strings" - - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" -) - -// format formats a single object, for pretty-printing. It also returns whether -// the value is a non-zero value. -func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) { - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false - case *pb.Object_StringValue: - return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0 - case *pb.Object_Int64Value: - return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0 - case *pb.Object_Uint64Value: - return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0 - case *pb.Object_DoubleValue: - return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0 - case *pb.Object_RefValue: - if x.RefValue == 0 { - return "nil", false - } - ref := fmt.Sprintf("g%dr%d", graph, x.RefValue) - if html { - ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) - } - return ref, true - case *pb.Object_SliceValue: - if x.SliceValue.RefValue == 0 { - return "nil", false - } - ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue) - if html { - ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) - } - return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true - case *pb.Object_ArrayValue: - if len(x.ArrayValue.Contents) == 0 { - return "[]", false - } - items := make([]string, 0, len(x.ArrayValue.Contents)+2) - zeros := make([]string, 0) // used to eliminate zero entries. - items = append(items, "[") - tabs := "\n" + strings.Repeat("\t", depth) - for i := 0; i < len(x.ArrayValue.Contents); i++ { - item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html) - if ok { - if len(zeros) > 0 { - items = append(items, zeros...) - zeros = nil - } - items = append(items, fmt.Sprintf("\t%s,", item)) - } else { - zeros = append(zeros, fmt.Sprintf("\t%s,", item)) - } - } - if len(zeros) > 0 { - items = append(items, fmt.Sprintf("\t... 
(%d zeros),", len(zeros))) - } - items = append(items, "]") - return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents) - case *pb.Object_StructValue: - if len(x.StructValue.Fields) == 0 { - return "struct{}", false - } - items := make([]string, 0, len(x.StructValue.Fields)+2) - items = append(items, "struct{") - tabs := "\n" + strings.Repeat("\t", depth) - allZero := true - for _, field := range x.StructValue.Fields { - element, ok := format(graph, depth+1, field.Value, html) - allZero = allZero && !ok - items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element)) - } - items = append(items, "}") - return strings.Join(items, tabs), !allZero - case *pb.Object_MapValue: - if len(x.MapValue.Keys) == 0 { - return "map{}", false - } - items := make([]string, 0, len(x.MapValue.Keys)+2) - items = append(items, "map{") - tabs := "\n" + strings.Repeat("\t", depth) - for i := 0; i < len(x.MapValue.Keys); i++ { - key, _ := format(graph, depth+1, x.MapValue.Keys[i], html) - value, _ := format(graph, depth+1, x.MapValue.Values[i], html) - items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) - } - items = append(items, "}") - return strings.Join(items, tabs), true - case *pb.Object_InterfaceValue: - if x.InterfaceValue.Type == "" { - return "interface(nil){}", false - } - element, _ := format(graph, depth+1, x.InterfaceValue.Value, html) - return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true - case *pb.Object_ByteArrayValue: - return printArray(reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values)) - case *pb.Object_Uint32ArrayValue: - return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - return printArray(reflect.ValueOf(x.Int16ArrayValue.Values)) - case *pb.Object_Int32ArrayValue: - return printArray(reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - return printArray(reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - return printArray(reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - return printArray(reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - return printArray(reflect.ValueOf(x.Float32ArrayValue.Values)) - } - - // Should not happen, but tolerate. - return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true -} - -// PrettyPrint reads the state stream from r, and pretty prints to w. -func PrettyPrint(w io.Writer, r io.Reader, html bool) error { - var ( - // current graph ID. - graph uint64 - - // current object ID. - id uint64 - ) - - if html { - fmt.Fprintf(w, "<pre>") - defer fmt.Fprintf(w, "</pre>") - } - - for { - // Find the first object to begin generation. - length, object, err := ReadHeader(r) - if err == io.EOF { - // Nothing else to do. - break - } else if err != nil { - return err - } - if !object { - // Increment the graph number & reset the ID. 
- graph++ - id = 0 - if length > 0 { - fmt.Fprintf(w, "(%d bytes non-object data)\n", length) - io.Copy(ioutil.Discard, &io.LimitedReader{ - R: r, - N: int64(length), - }) - } - continue - } - - // Read & unmarshal the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return err - } - } - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return err - } - - id++ // First object must be one. - str, _ := format(graph, 0, obj, html) - tag := fmt.Sprintf("g%dr%d", graph, id) - if html { - tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag) - } - if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { - return err - } - } - - return nil -} - -func printArray(s reflect.Value) (string, bool) { - zero := reflect.Zero(s.Type().Elem()).Interface() - z := "0" - switch s.Type().Elem().Kind() { - case reflect.Bool: - z = "false" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - case reflect.Float32, reflect.Float64: - default: - return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true - } - - zeros := 0 - items := make([]string, 0, s.Len()) - for i := 0; i <= s.Len(); i++ { - if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) { - zeros++ - continue - } - if zeros > 0 { - if zeros <= 4 { - for ; zeros > 0; zeros-- { - items = append(items, z) - } - } else { - items = append(items, fmt.Sprintf("(%d %ss)", zeros, z)) - zeros = 0 - } - } - if i < s.Len() { - items = append(items, fmt.Sprintf("%v", s.Index(i).Interface())) - } - } - return "[" + strings.Join(items, ",") + "]", zeros < s.Len() -} diff --git a/pkg/state/state.go b/pkg/state/state.go index 03ae2dbb0..acb629969 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -31,210 +31,226 @@ // Uint64 default // Float32 default // Float64 default -// Complex64 custom -// Complex128 custom +// Complex64 default +// Complex128 default // Array default // Chan custom // Func custom -// Interface custom -// Map default (*) +// Interface default +// Map default // Ptr default // Slice default // String default -// Struct custom +// Struct custom (*) Unless zero-sized. // UnsafePointer custom // -// (*) Maps are treated as value types by this package, even if they are -// pointers internally. If you want to save two independent references -// to the same map value, you must explicitly use a pointer to a map. +// See README.md for an overview of how encoding and decoding works. package state import ( "context" "fmt" - "io" "reflect" "runtime" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) +// objectID is a unique identifier assigned to each object to be serialized. +// Each instance of an object is considered separately, i.e. if there are two +// objects of the same type in the object graph being serialized, they'll be +// assigned unique objectIDs. +type objectID uint32 + +// typeID is the identifier for a type. Types are serialized and tracked +// alongside objects in order to avoid the overhead of encoding field names in +// all objects. +type typeID uint32 + // ErrState is returned when an error is encountered during encode/decode. type ErrState struct { // err is the underlying error. err error - // path is the visit path from root to the current object. - path string - // trace is the stack trace. 
	trace string
}

// Error returns a sensible description of the state error.
func (e *ErrState) Error() string {
-	return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
+	return fmt.Sprintf("%v:\n%s", e.err, e.trace)
}

-// UnwrapErrState returns the underlying error in ErrState.
-//
-// If err is not *ErrState, err is returned directly.
-func UnwrapErrState(err error) error {
-	if e, ok := err.(*ErrState); ok {
-		return e.err
-	}
-	return err
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+	return e.err
}

// Save saves the given object state.
-func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
	// Create the encoding state.
-	es := &encodeState{
-		ctx:         ctx,
-		idsByObject: make(map[uintptr]uint64),
-		w:           w,
-		stats:       stats,
+	es := encodeState{
+		ctx:        ctx,
+		w:          w,
+		types:      makeTypeEncodeDatabase(),
+		zeroValues: make(map[reflect.Type]*objectEncodeState),
	}

	// Perform the encoding.
-	return es.safely(func() {
-		es.Serialize(reflect.ValueOf(rootPtr).Elem())
+	err := safely(func() {
+		es.Save(reflect.ValueOf(rootPtr).Elem())
	})
+	return es.stats, err
}

// Load loads a checkpoint.
-func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
	// Create the decoding state.
-	ds := &decodeState{
-		ctx:         ctx,
-		objectsByID: make(map[uint64]*objectState),
-		deferred:    make(map[uint64]*pb.Object),
-		r:           r,
-		stats:       stats,
+	ds := decodeState{
+		ctx:      ctx,
+		r:        r,
+		types:    makeTypeDecodeDatabase(),
+		deferred: make(map[objectID]wire.Object),
	}

	// Attempt our decode.
-	return ds.safely(func() {
-		ds.Deserialize(reflect.ValueOf(rootPtr).Elem())
+	err := safely(func() {
+		ds.Load(reflect.ValueOf(rootPtr).Elem())
	})
+	return ds.stats, err
}

-// Fns are the state dispatch functions.
-type Fns struct {
-	// Save is a function like Save(concreteType, Map).
-	Save interface{}
-
-	// Load is a function like Load(concreteType, Map).
-	Load interface{}
+// Sink is used for Type.StateSave.
+type Sink struct {
+	internal objectEncoder
}

-// Save executes the save function.
-func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
-	reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// Save adds the given object to the map.
+//
+// You should always pass pointers to the object you are saving. For example:
+//
+// type X struct {
+//	A int
+//	B *int
+// }
+//
+// func (x *X) StateTypeName() string {
+//	return "pkg.X"
+// }
+//
+// func (x *X) StateFields() []string {
+//	return []string{
+//		"A",
+//		"B",
+//	}
+// }
+//
+// func (x *X) StateSave(m Sink) {
+//	m.Save(0, &x.A) // Field is A.
+//	m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+//	m.Load(0, &x.A) // Field is A.
+//	m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+	s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
}

-// Load executes the load function.
-func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
-	reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+//	m.SaveValue(0, int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+//	m.LoadValue(0, new(int64), func(x interface{}) {
+//		x.A = P.Foo(x.(int64))
+//	})
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+	s.internal.save(slot, reflect.ValueOf(obj))
}

-// validateStateFn ensures types are correct.
-func validateStateFn(fn interface{}, typ reflect.Type) bool {
-	fnTyp := reflect.TypeOf(fn)
-	if fnTyp.Kind() != reflect.Func {
-		return false
-	}
-	if fnTyp.NumIn() != 2 {
-		return false
-	}
-	if fnTyp.NumOut() != 0 {
-		return false
-	}
-	if fnTyp.In(0) != typ {
-		return false
-	}
-	if fnTyp.In(1) != reflect.TypeOf(Map{}) {
-		return false
-	}
-	return true
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+	return s.internal.es.ctx
}

-// Validate validates all state functions.
-func (fns *Fns) Validate(typ reflect.Type) bool {
-	return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing runtime reflection required.
+//
+// All these methods can be automatically generated by the go_statify tool.
+type Type interface {
+	// StateTypeName returns the type's name.
+	//
+	// This is used for matching type information during encoding and
+	// decoding, as well as dynamic interface dispatch. This should be
+	// globally unique.
+	StateTypeName() string
+
+	// StateFields returns information about the type.
+	//
+	// Fields is the set of fields for the object. Calls to Sink.Save and
+	// Source.Load must be made in-order with respect to these fields.
+	//
+	// This will be called at most once per serialization.
+	StateFields() []string
}

-type typeDatabase struct {
-	// nameToType is a forward lookup table.
-	nameToType map[string]reflect.Type
-
-	// typeToName is the reverse lookup table.
-	typeToName map[reflect.Type]string
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+	// StateSave saves the state of the object to the given Sink.
+	StateSave(Sink)

-	// typeToFns is the function lookup table.
-	typeToFns map[reflect.Type]Fns
+	// StateLoad loads the state of the object.
+	StateLoad(Source)
}

-// registeredTypes is a database used for SaveInterface and LoadInterface.
-var registeredTypes = typeDatabase{
-	nameToType: make(map[string]reflect.Type),
-	typeToName: make(map[reflect.Type]string),
-	typeToFns:  make(map[reflect.Type]Fns),
+// Source is used for Type.StateLoad.
+type Source struct {
+	internal objectDecoder
}

-// register registers a type under the given name. This will generally be
-// called via init() methods, and therefore uses panic to propagate errors.
-func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
-	// We can't allow name collisions.
-	if ot, ok := t.nameToType[name]; ok {
-		panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name()))
-	}
-
-	// Or multiple registrations.
-	if on, ok := t.typeToName[typ]; ok {
-		panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on))
-	}
-
-	t.nameToType[name] = typ
-	t.typeToName[typ] = name
-	t.typeToFns[typ] = fns
+// Load loads the given object, which must be passed as a pointer.
+//
+// See Sink.Save for an example.
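(Pulling the pieces together, a hedged end-to-end sketch of the slot-based API; the point type is illustrative, and in practice these methods would be generated by the go_statify tool rather than written by hand:

	type point struct {
		X, Y int
	}

	func (p *point) StateTypeName() string { return "pkg.point" }

	func (p *point) StateFields() []string { return []string{"X", "Y"} }

	func (p *point) StateSave(s Sink) {
		s.Save(0, &p.X)
		s.Save(1, &p.Y)
	}

	func (p *point) StateLoad(s Source) {
		s.Load(0, &p.X)
		s.Load(1, &p.Y)
	}

Slots index into the StateFields() list, which is why Save and Load calls must follow the declared field order.)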
+func (s Source) Load(slot int, objPtr interface{}) {
+	s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
}

-// lookupType finds a type given a name.
-func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
-	typ, ok := t.nameToType[name]
-	return typ, ok
+// LoadWait loads the given object from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+	s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
}

-// lookupName finds a name given a type.
-func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
-	name, ok := t.typeToName[typ]
-	return name, ok
+// LoadValue loads the given object value from the map.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+	o := reflect.ValueOf(objPtr)
+	s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
}

-// lookupFns finds functions given a type.
-func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
-	fns, ok := t.typeToFns[typ]
-	return fns, ok
+// AfterLoad schedules a function execution when all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not be executed until all of the current object's
+// dependencies' AfterLoad() logic, if any, has been executed.
+func (s Source) AfterLoad(fn func()) {
+	s.internal.afterLoad(fn)
}

-// Register must be called for any interface implementation types that
-// implements Loader.
-//
-// Register should be called either immediately after startup or via init()
-// methods. Double registration of either names or types will result in a panic.
-//
-// No synchronization is provided; this should only be called in init.
-//
-// Example usage:
-//
-//	state.Register("Foo", (*Foo)(nil), state.Fns{
-//		Save: (*Foo).Save,
-//		Load: (*Foo).Load,
-//	})
-//
-func Register(name string, instance interface{}, fns Fns) {
-	registeredTypes.register(name, reflect.TypeOf(instance), fns)
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+	return s.internal.ds.ctx
}

// IsZeroValue checks if the given value is the zero value.
@@ -244,72 +260,14 @@ func IsZeroValue(val interface{}) bool {
	return val == nil || reflect.ValueOf(val).Elem().IsZero()
}

-// step captures one encoding / decoding step. On each step, there is up to one
-// choice made, which is captured by non-nil param. We intentionally do not
-// eagerly create the final path string, as that will only be needed upon panic.
-type step struct {
-	// dereference indicate if the current object is obtained by
-	// dereferencing a pointer.
-	dereference bool
-
-	// format is the formatting string that takes param below, if
-	// non-nil. For example, in array indexing case, we have "[%d]".
-	format string
-
-	// param stores the choice made at the current encoding / decoding step.
-	// For eaxmple, in array indexing case, param stores the index. When no
-	// choice is made, e.g. dereference, param should be nil.
-	param interface{}
-}
-
-// recoverable is the state encoding / decoding panic recovery facility. It is
-// also used to store encoding / decoding steps as well as the reference to the
-// original queued object from which the current object is dispatched.
The -// complete encoding / decoding path is synthesised from the steps in all queued -// objects leading to the current object. -type recoverable struct { - from *recoverable - steps []step +// Failf is a wrapper around panic that should be used to generate errors that +// can be caught during saving and loading. +func Failf(fmtStr string, v ...interface{}) { + panic(fmt.Errorf(fmtStr, v...)) } -// push enters a new context level. -func (sr *recoverable) push(dereference bool, format string, param interface{}) { - sr.steps = append(sr.steps, step{dereference, format, param}) -} - -// pop exits the current context level. -func (sr *recoverable) pop() { - if len(sr.steps) <= 1 { - return - } - sr.steps = sr.steps[:len(sr.steps)-1] -} - -// path returns the complete encoding / decoding path from root. This is only -// called upon panic. -func (sr *recoverable) path() string { - if sr.from == nil { - return "root" - } - p := sr.from.path() - for _, s := range sr.steps { - if s.dereference { - p = fmt.Sprintf("*(%s)", p) - } - if s.param == nil { - p += s.format - } else { - p += fmt.Sprintf(s.format, s.param) - } - } - return p -} - -func (sr *recoverable) copy() recoverable { - return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)} -} - -// safely executes the given function, catching a panic and unpacking as an error. +// safely executes the given function, catching a panic and unpacking as an +// error. // // The error flow through the state package uses panic and recover. There are // two important reasons for this: @@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable { // method doesn't add a lot of value. If there are specific error conditions // that you'd like to handle, you should add appropriate functionality to // objects themselves prior to calling Save() and Load(). -func (sr *recoverable) safely(fn func()) (err error) { +func safely(fn func()) (err error) { defer func() { if r := recover(); r != nil { + if es, ok := r.(*ErrState); ok { + err = es // Propagate. + return + } + + // Build a new state error. es := new(ErrState) if e, ok := r.(error); ok { es.err = e @@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) { es.err = fmt.Errorf("%v", r) } - es.path = sr.path() - // Make a stack. We don't know how big it will be ahead // of time, but want to make sure we get the whole // thing. So we just do a stupid brute force approach. diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go new file mode 100644 index 000000000..4281aed6d --- /dev/null +++ b/pkg/state/state_norace.go @@ -0,0 +1,19 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package state + +var raceEnabled = false diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go new file mode 100644 index 000000000..8232981ce --- /dev/null +++ b/pkg/state/state_race.go @@ -0,0 +1,19 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build race + +package state + +var raceEnabled = true diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go deleted file mode 100644 index d7221e9e8..000000000 --- a/pkg/state/state_test.go +++ /dev/null @@ -1,721 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "bytes" - "context" - "io/ioutil" - "math" - "reflect" - "testing" -) - -// TestCase is used to define a single success/failure testcase of -// serialization of a set of objects. -type TestCase struct { - // Name is the name of the test case. - Name string - - // Objects is the list of values to serialize. - Objects []interface{} - - // Fail is whether the test case is supposed to fail or not. - Fail bool -} - -// runTest runs all testcases. -func runTest(t *testing.T, tests []TestCase) { - for _, test := range tests { - t.Logf("TEST %s:", test.Name) - for i, root := range test.Objects { - t.Logf(" case#%d: %#v", i, root) - - // Save the passed object. - saveBuffer := &bytes.Buffer{} - saveObjectPtr := reflect.New(reflect.TypeOf(root)) - saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Save failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Save failed as expected: %v", err) - continue - } - - // Load a new copy of the object. - loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Load failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Load failed as expected: %v", err) - continue - } - - // Compare the values. - loadedValue := loadObjectPtr.Elem().Interface() - if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail { - t.Errorf(" FAIL: Objects differs; got %#v", loadedValue) - continue - } else if !eq { - t.Logf(" PASS: Object different as expected.") - continue - } - - // Everything went okay. Is that good? - if test.Fail { - t.Errorf(" FAIL: Unexpected success.") - } else { - t.Logf(" PASS: Success.") - } - } - } -} - -// dumbStruct is a struct which does not implement the loader/saver interface. -// We expect that serialization of this struct will fail. 
-type dumbStruct struct { - A int - B int -} - -// smartStruct is a struct which does implement the loader/saver interface. -// We expect that serialization of this struct will succeed. -type smartStruct struct { - A int - B int -} - -func (s *smartStruct) save(m Map) { - m.Save("A", &s.A) - m.Save("B", &s.B) -} - -func (s *smartStruct) load(m Map) { - m.Load("A", &s.A) - m.Load("B", &s.B) -} - -// valueLoadStruct uses a value load. -type valueLoadStruct struct { - v int -} - -func (v *valueLoadStruct) save(m Map) { - m.SaveValue("v", v.v) -} - -func (v *valueLoadStruct) load(m Map) { - m.LoadValue("v", new(int), func(value interface{}) { - v.v = value.(int) - }) -} - -// afterLoadStruct has an AfterLoad function. -type afterLoadStruct struct { - v int -} - -func (a *afterLoadStruct) save(m Map) { -} - -func (a *afterLoadStruct) load(m Map) { - m.AfterLoad(func() { - a.v++ - }) -} - -// genericContainer is a generic dispatcher. -type genericContainer struct { - v interface{} -} - -func (g *genericContainer) save(m Map) { - m.Save("v", &g.v) -} - -func (g *genericContainer) load(m Map) { - m.Load("v", &g.v) -} - -// sliceContainer is a generic slice. -type sliceContainer struct { - v []interface{} -} - -func (s *sliceContainer) save(m Map) { - m.Save("v", &s.v) -} - -func (s *sliceContainer) load(m Map) { - m.Load("v", &s.v) -} - -// mapContainer is a generic map. -type mapContainer struct { - v map[int]interface{} -} - -func (mc *mapContainer) save(m Map) { - m.Save("v", &mc.v) -} - -func (mc *mapContainer) load(m Map) { - // Some of the test cases below assume legacy behavior wherein maps - // will automatically inherit dependencies. - m.LoadWait("v", &mc.v) -} - -// dumbMap is a map which does not implement the loader/saver interface. -// Serialization of this map will default to the standard encode/decode logic. -type dumbMap map[string]int - -// pointerStruct contains various pointers, shared and non-shared, and pointers -// to pointers. We expect that serialization will respect the structure. -type pointerStruct struct { - A *int - B *int - C *int - D *int - - AA **int - BB **int -} - -func (p *pointerStruct) save(m Map) { - m.Save("A", &p.A) - m.Save("B", &p.B) - m.Save("C", &p.C) - m.Save("D", &p.D) - m.Save("AA", &p.AA) - m.Save("BB", &p.BB) -} - -func (p *pointerStruct) load(m Map) { - m.Load("A", &p.A) - m.Load("B", &p.B) - m.Load("C", &p.C) - m.Load("D", &p.D) - m.Load("AA", &p.AA) - m.Load("BB", &p.BB) -} - -// testInterface is a trivial interface example. -type testInterface interface { - Foo() -} - -// testImpl is a trivial implementation of testInterface. -type testImpl struct { -} - -// Foo satisfies testInterface. -func (t *testImpl) Foo() { -} - -// testImpl is trivially serializable. -func (t *testImpl) save(m Map) { -} - -// testImpl is trivially serializable. -func (t *testImpl) load(m Map) { -} - -// testI demonstrates interface dispatching. -type testI struct { - I testInterface -} - -func (t *testI) save(m Map) { - m.Save("I", &t.I) -} - -func (t *testI) load(m Map) { - m.Load("I", &t.I) -} - -// cycleStruct is used to implement basic cycles. -type cycleStruct struct { - c *cycleStruct -} - -func (c *cycleStruct) save(m Map) { - m.Save("c", &c.c) -} - -func (c *cycleStruct) load(m Map) { - m.Load("c", &c.c) -} - -// badCycleStruct actually has deadlocking dependencies. -// -// This should pass if b.b = {nil|b} and fail otherwise. 
-type badCycleStruct struct { - b *badCycleStruct -} - -func (b *badCycleStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *badCycleStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(func() { - // This is not executable, since AfterLoad requires that the - // object and all dependencies are complete. This should cause - // a deadlock error during load. - }) -} - -// emptyStructPointer points to an empty struct. -type emptyStructPointer struct { - nothing *struct{} -} - -func (e *emptyStructPointer) save(m Map) { - m.Save("nothing", &e.nothing) -} - -func (e *emptyStructPointer) load(m Map) { - m.Load("nothing", &e.nothing) -} - -// truncateInteger truncates an integer. -type truncateInteger struct { - v int64 - v2 int32 -} - -func (t *truncateInteger) save(m Map) { - t.v2 = int32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = int64(t.v2) -} - -// truncateUnsignedInteger truncates an unsigned integer. -type truncateUnsignedInteger struct { - v uint64 - v2 uint32 -} - -func (t *truncateUnsignedInteger) save(m Map) { - t.v2 = uint32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateUnsignedInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = uint64(t.v2) -} - -// truncateFloat truncates a floating point number. -type truncateFloat struct { - v float64 - v2 float32 -} - -func (t *truncateFloat) save(m Map) { - t.v2 = float32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateFloat) load(m Map) { - m.Load("v", &t.v2) - t.v = float64(t.v2) -} - -func TestTypes(t *testing.T) { - // x and y are basic integers, while xp points to x. - x := 1 - y := 2 - xp := &x - - // cs is a single object cycle. - cs := cycleStruct{nil} - cs.c = &cs - - // cs1 and cs2 are in a two object cycle. - cs1 := cycleStruct{nil} - cs2 := cycleStruct{nil} - cs1.c = &cs2 - cs2.c = &cs1 - - // bs is a single object cycle. - bs := badCycleStruct{nil} - bs.b = &bs - - // bs2 and bs2 are in a deadlocking cycle. - bs1 := badCycleStruct{nil} - bs2 := badCycleStruct{nil} - bs1.b = &bs2 - bs2.b = &bs1 - - // regular nils. - var ( - nilmap dumbMap - nilslice []byte - ) - - // embed points to embedded fields. - embed1 := pointerStruct{} - embed1.AA = &embed1.A - embed2 := pointerStruct{} - embed2.BB = &embed2.B - - // es1 contains two structs pointing to the same empty struct. 
- es := emptyStructPointer{new(struct{})} - es1 := []emptyStructPointer{es, es} - - tests := []TestCase{ - { - Name: "bool", - Objects: []interface{}{ - true, - false, - }, - }, - { - Name: "integers", - Objects: []interface{}{ - int(0), - int(1), - int(-1), - int8(0), - int8(1), - int8(-1), - int16(0), - int16(1), - int16(-1), - int32(0), - int32(1), - int32(-1), - int64(0), - int64(1), - int64(-1), - }, - }, - { - Name: "unsigned integers", - Objects: []interface{}{ - uint(0), - uint(1), - uint8(0), - uint8(1), - uint16(0), - uint16(1), - uint32(1), - uint64(0), - uint64(1), - }, - }, - { - Name: "strings", - Objects: []interface{}{ - "", - "foo", - "bar", - "\xa0", - }, - }, - { - Name: "slices", - Objects: []interface{}{ - []int{-1, 0, 1}, - []*int{&x, &x, &x}, - []int{1, 2, 3}[0:1], - []int{1, 2, 3}[1:2], - make([]byte, 32), - make([]byte, 32)[:16], - make([]byte, 32)[:16:20], - nilslice, - }, - }, - { - Name: "arrays", - Objects: []interface{}{ - &[1048576]bool{false, true, false, true}, - &[1048576]uint8{0, 1, 2, 3}, - &[1048576]byte{0, 1, 2, 3}, - &[1048576]uint16{0, 1, 2, 3}, - &[1048576]uint{0, 1, 2, 3}, - &[1048576]uint32{0, 1, 2, 3}, - &[1048576]uint64{0, 1, 2, 3}, - &[1048576]uintptr{0, 1, 2, 3}, - &[1048576]int8{0, -1, -2, -3}, - &[1048576]int16{0, -1, -2, -3}, - &[1048576]int32{0, -1, -2, -3}, - &[1048576]int64{0, -1, -2, -3}, - &[1048576]float32{0, 1.1, 2.2, 3.3}, - &[1048576]float64{0, 1.1, 2.2, 3.3}, - }, - }, - { - Name: "pointers", - Objects: []interface{}{ - &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp}, - &pointerStruct{}, - }, - }, - { - Name: "empty struct", - Objects: []interface{}{ - struct{}{}, - }, - }, - { - Name: "unenlightened structs", - Objects: []interface{}{ - &dumbStruct{A: 1, B: 2}, - }, - Fail: true, - }, - { - Name: "enlightened structs", - Objects: []interface{}{ - &smartStruct{A: 1, B: 2}, - }, - }, - { - Name: "load-hooks", - Objects: []interface{}{ - &afterLoadStruct{v: 1}, - &valueLoadStruct{v: 1}, - &genericContainer{v: &afterLoadStruct{v: 1}}, - &genericContainer{v: &valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, - }, - }, - { - Name: "maps", - Objects: []interface{}{ - dumbMap{"a": -1, "b": 0, "c": 1}, - map[smartStruct]int{{}: 0, {A: 1}: 1}, - nilmap, - &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}}, - }, - }, - { - Name: "interfaces", - Objects: []interface{}{ - &testI{&testImpl{}}, - &testI{nil}, - &testI{(*testImpl)(nil)}, - }, - }, - { - Name: "unregistered-interfaces", - Objects: []interface{}{ - &genericContainer{v: afterLoadStruct{v: 1}}, - &genericContainer{v: valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}}, - }, - Fail: true, - }, - { - Name: "cycles", - Objects: []interface{}{ - &cs, - &cs1, - &cycleStruct{&cs1}, - &cycleStruct{&cs}, - &badCycleStruct{nil}, - &bs, - }, - }, - { - Name: "deadlock", - Objects: []interface{}{ - &bs1, - }, - Fail: true, - }, - { - Name: "embed", - Objects: []interface{}{ - &embed1, - &embed2, - }, - Fail: true, - }, - { - Name: "empty structs", - Objects: []interface{}{ - new(struct{}), - es, - es1, - }, - }, - { - 
Name: "truncated okay", - Objects: []interface{}{ - &truncateInteger{v: 1}, - &truncateUnsignedInteger{v: 1}, - &truncateFloat{v: 1.0}, - }, - }, - { - Name: "truncated bad", - Objects: []interface{}{ - &truncateInteger{v: math.MaxInt32 + 1}, - &truncateUnsignedInteger{v: math.MaxUint32 + 1}, - &truncateFloat{v: math.MaxFloat32 * 2}, - }, - Fail: true, - }, - } - - runTest(t, tests) -} - -// benchStruct is used for benchmarking. -type benchStruct struct { - b *benchStruct - - // Dummy data is included to ensure that these objects are large. - // This is to detect possible regression when registering objects. - _ [4096]byte -} - -func (b *benchStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *benchStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(b.afterLoad) -} - -func (b *benchStruct) afterLoad() { - // Do nothing, just force scheduling. -} - -// buildObject builds a benchmark object. -func buildObject(n int) (b *benchStruct) { - for i := 0; i < n; i++ { - b = &benchStruct{b: b} - } - return -} - -func BenchmarkEncoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var stats Stats - b.StartTimer() - if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { - b.Errorf("save failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func BenchmarkDecoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var newBS benchStruct - buf := &bytes.Buffer{} - if err := Save(context.Background(), buf, bs, nil); err != nil { - b.Errorf("save failed: %v", err) - } - var stats Stats - b.StartTimer() - if err := Load(context.Background(), buf, &newBS, &stats); err != nil { - b.Errorf("load failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func init() { - Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{ - Save: (*smartStruct).save, - Load: (*smartStruct).load, - }) - Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{ - Save: (*afterLoadStruct).save, - Load: (*afterLoadStruct).load, - }) - Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{ - Save: (*valueLoadStruct).save, - Load: (*valueLoadStruct).load, - }) - Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{ - Save: (*genericContainer).save, - Load: (*genericContainer).load, - }) - Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{ - Save: (*sliceContainer).save, - Load: (*sliceContainer).load, - }) - Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{ - Save: (*mapContainer).save, - Load: (*mapContainer).load, - }) - Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{ - Save: (*pointerStruct).save, - Load: (*pointerStruct).load, - }) - Register("stateTest.testImpl", (*testImpl)(nil), Fns{ - Save: (*testImpl).save, - Load: (*testImpl).load, - }) - Register("stateTest.testI", (*testI)(nil), Fns{ - Save: (*testI).save, - Load: (*testI).load, - }) - Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{ - Save: (*cycleStruct).save, - Load: (*cycleStruct).load, - }) - Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{ - Save: (*badCycleStruct).save, - Load: (*badCycleStruct).load, - }) - Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{ - Save: (*emptyStructPointer).save, - Load: (*emptyStructPointer).load, - }) - Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{ - Save: (*truncateInteger).save, - Load: 
(*truncateInteger).load, - }) - Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{ - Save: (*truncateUnsignedInteger).save, - Load: (*truncateUnsignedInteger).load, - }) - Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{ - Save: (*truncateFloat).save, - Load: (*truncateFloat).load, - }) - Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{ - Save: (*benchStruct).save, - Load: (*benchStruct).load, - }) -} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e7581c09b..d6c89c7e9 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/compressio", + "//pkg/state/wire", ], ) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index c0f4c4954..bdfb800fb 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -57,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/compressio" + "gvisor.dev/gvisor/pkg/state/wire" ) // keySize is the AES-256 key length. @@ -83,10 +84,16 @@ var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size // ErrMetadataInvalid is returned if passed metadata is invalid. var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _") +// WriteCloser is an io.Closer and wire.Writer. +type WriteCloser interface { + wire.Writer + io.Closer +} + // NewWriter returns a state data writer for a statefile. // // Note that the returned WriteCloser must be closed. -func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) { +func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) { if metadata == nil { metadata = make(map[string]string) } @@ -215,7 +222,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { } // NewReader returns a reader for a statefile. -func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { +func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) { // Read the metadata with the hash. h := hmac.New(sha256.New, key) metadata, err := metadata(r, h) @@ -224,9 +231,9 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { } // Wrap in compression. - rc, err := compressio.NewReader(r, key) + cr, err := compressio.NewReader(r, key) if err != nil { return nil, nil, err } - return rc, metadata, nil + return cr, metadata, nil } diff --git a/pkg/state/stats.go b/pkg/state/stats.go index eb51cda47..eaec664a1 100644 --- a/pkg/state/stats.go +++ b/pkg/state/stats.go @@ -17,7 +17,6 @@ package state import ( "bytes" "fmt" - "reflect" "sort" "time" ) @@ -35,92 +34,81 @@ type statEntry struct { // All exported receivers accept nil. type Stats struct { // byType contains a breakdown of time spent by type. - byType map[reflect.Type]*statEntry + // + // This is indexed *directly* by typeID, including zero. + byType []statEntry // stack contains objects in progress. - stack []reflect.Type + stack []typeID + + // names contains type names. + // + // This is also indexed *directly* by typeID, including zero, which we + // hard-code as "state.default". This is only resolved by calling fini + // on the stats object. + names []string // last is the last start time. last time.Time } -// sample adds the samples to the given object. 
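(Referring back to the statefile changes above, a hedged sketch of how the new interfaces compose; saveState is a hypothetical helper, not code from this commit:

	func saveState(f io.Writer, key []byte, rootPtr interface{}) error {
		w, err := statefile.NewWriter(f, key, nil)
		if err != nil {
			return err
		}
		defer w.Close()
		// The returned WriteCloser is also a wire.Writer, so it can be
		// handed straight to state.Save.
		_, err = state.Save(context.Background(), w, rootPtr)
		return err
	}

Per the NewWriter documentation, the returned writer must be closed, hence the defer.)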
-func (s *Stats) sample(typ reflect.Type) { - now := time.Now() - s.byType[typ].total += now.Sub(s.last) - s.last = now +// init initializes statistics. +func (s *Stats) init() { + s.last = time.Now() + s.stack = append(s.stack, 0) } -// Add adds a sample count. -func (s *Stats) Add(obj reflect.Value) { - if s == nil { - return - } - if s.byType == nil { - s.byType = make(map[reflect.Type]*statEntry) - } - typ := obj.Type() - entry, ok := s.byType[typ] - if !ok { - entry = new(statEntry) - s.byType[typ] = entry +// fini finalizes statistics. +func (s *Stats) fini(resolve func(id typeID) string) { + s.done() + + // Resolve all type names. + s.names = make([]string, len(s.byType)) + s.names[0] = "state.default" // See above. + for id := typeID(1); int(id) < len(s.names); id++ { + s.names[id] = resolve(id) } - entry.count++ } -// Remove removes a sample count. It should only be called after a previous -// Add(). -func (s *Stats) Remove(obj reflect.Value) { - if s == nil { - return +// sample adds the samples to the given object. +func (s *Stats) sample(id typeID) { + now := time.Now() + if len(s.byType) <= int(id) { + // Allocate all the missing entries in one fell swoop. + s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...) } - typ := obj.Type() - entry := s.byType[typ] - entry.count-- + s.byType[id].total += now.Sub(s.last) + s.last = now } -// Start starts a sample. -func (s *Stats) Start(obj reflect.Value) { - if s == nil { - return - } - if len(s.stack) > 0 { - last := s.stack[len(s.stack)-1] - s.sample(last) - } else { - // First time sample. - s.last = time.Now() - } - s.stack = append(s.stack, obj.Type()) +// start starts a sample. +func (s *Stats) start(id typeID) { + last := s.stack[len(s.stack)-1] + s.sample(last) + s.stack = append(s.stack, id) } -// Done finishes the current sample. -func (s *Stats) Done() { - if s == nil { - return - } +// done finishes the current sample. +func (s *Stats) done() { last := s.stack[len(s.stack)-1] s.sample(last) + s.byType[last].count++ s.stack = s.stack[:len(s.stack)-1] } type sliceEntry struct { - typ reflect.Type + name string entry *statEntry } // String returns a table representation of the stats. func (s *Stats) String() string { - if s == nil || len(s.byType) == 0 { - return "(no data)" - } - // Build a list of stat entries. ss := make([]sliceEntry, 0, len(s.byType)) - for typ, entry := range s.byType { + for id := 0; id < len(s.names); id++ { ss = append(ss, sliceEntry{ - typ: typ, - entry: entry, + name: s.names[id], + entry: &s.byType[id], }) } @@ -136,17 +124,22 @@ func (s *Stats) String() string { total time.Duration ) buf.WriteString("\n") - buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type")) - buf.WriteString("-------------+----------+----------+-------------\n") + buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type")) + buf.WriteString("-----------------+----------+------------------+----------------\n") for _, se := range ss { + if se.entry.count == 0 { + // Since we store all types linearly, we are not + // guaranteed that any entry actually has time. 
+ continue + } count += se.entry.count total += se.entry.total per := se.entry.total / time.Duration(se.entry.count) - buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n", - se.entry.total, se.entry.count, per, se.typ.String())) + buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n", + se.entry.total, se.entry.count, per, se.name)) } - buf.WriteString("-------------+----------+----------+-------------\n") - buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]", + buf.WriteString("-----------------+----------+------------------+----------------\n") + buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]", total, count, total/time.Duration(count))) return string(buf.Bytes()) } diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD new file mode 100644 index 000000000..9297cafbe --- /dev/null +++ b/pkg/state/tests/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "tests", + srcs = [ + "array.go", + "bench.go", + "integer.go", + "load.go", + "map.go", + "register.go", + "struct.go", + "tests.go", + ], + deps = [ + "//pkg/state", + "//pkg/state/pretty", + ], +) + +go_test( + name = "tests_test", + size = "small", + srcs = [ + "array_test.go", + "bench_test.go", + "bool_test.go", + "float_test.go", + "integer_test.go", + "load_test.go", + "map_test.go", + "register_test.go", + "string_test.go", + "struct_test.go", + ], + library = ":tests", + deps = [ + "//pkg/state", + "//pkg/state/wire", + ], +) diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go new file mode 100644 index 000000000..0972a80e7 --- /dev/null +++ b/pkg/state/tests/array.go @@ -0,0 +1,35 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type arrayContainer struct { + v [1]interface{} +} + +// +stateify savable +type arrayPtrContainer struct { + v *[1]interface{} +} + +// +stateify savable +type sliceContainer struct { + v []interface{} +} + +// +stateify savable +type slicePtrContainer struct { + v *[]interface{} +} diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go new file mode 100644 index 000000000..a347b2947 --- /dev/null +++ b/pkg/state/tests/array_test.go @@ -0,0 +1,134 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
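+// The tests in this file feed the container types from array.go through
+// runTestCases (defined in tests.go), which saves each object, reloads it,
+// and compares the result. The underlying round-trip, sketched with the
+// entry points used by that driver (obj is a placeholder for any savable
+// value):
+//
+//	var buf bytes.Buffer
+//	if _, err := state.Save(context.Background(), &buf, &obj); err != nil {
+//		// Save failed.
+//	}
+//	if _, err := state.Load(context.Background(), bytes.NewReader(buf.Bytes()), &obj); err != nil {
+//		// Load failed.
+//	}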
+ +package tests + +import ( + "reflect" + "testing" +) + +var allArrayPrimitives = []interface{}{ + [1]bool{}, + [1]bool{true}, + [2]bool{false, true}, + [1]int{}, + [1]int{1}, + [2]int{0, 1}, + [1]int8{}, + [1]int8{1}, + [2]int8{0, 1}, + [1]int16{}, + [1]int16{1}, + [2]int16{0, 1}, + [1]int32{}, + [1]int32{1}, + [2]int32{0, 1}, + [1]int64{}, + [1]int64{1}, + [2]int64{0, 1}, + [1]uint{}, + [1]uint{1}, + [2]uint{0, 1}, + [1]uintptr{}, + [1]uintptr{1}, + [2]uintptr{0, 1}, + [1]uint8{}, + [1]uint8{1}, + [2]uint8{0, 1}, + [1]uint16{}, + [1]uint16{1}, + [2]uint16{0, 1}, + [1]uint32{}, + [1]uint32{1}, + [2]uint32{0, 1}, + [1]uint64{}, + [1]uint64{1}, + [2]uint64{0, 1}, + [1]string{}, + [1]string{""}, + [1]string{nonEmptyString}, + [2]string{"", nonEmptyString}, +} + +func TestArrayPrimitives(t *testing.T) { + runTestCases(t, false, "plain", flatten(allArrayPrimitives)) + runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives)))) +} + +func TestSlices(t *testing.T) { + var allSlices = flatten( + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + return v.Slice(0, v.Len()).Interface(), true + }), + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + if v.Len() == 0 { + // Return the pure "nil" value for the slice. + return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true + } + return v.Slice(1, v.Len()).Interface(), true + }), + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + if v.Len() == 0 { + // Return the zero-valued slice. + return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true + } + return v.Slice(0, v.Len()-1).Interface(), true + }), + ) + runTestCases(t, false, "plain", allSlices) + runTestCases(t, false, "pointers", pointersTo(allSlices)) + runTestCases(t, false, "interfaces", interfacesTo(allSlices)) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices))) +} + +func TestArrayContainers(t *testing.T) { + var ( + emptyArray [1]interface{} + fullArray [1]interface{} + ) + fullArray[0] = &emptyArray + runTestCases(t, false, "", []interface{}{ + arrayContainer{v: emptyArray}, + arrayContainer{v: fullArray}, + arrayPtrContainer{v: nil}, + arrayPtrContainer{v: &emptyArray}, + arrayPtrContainer{v: &fullArray}, + }) +} + +func TestSliceContainers(t *testing.T) { + var ( + nilSlice []interface{} + emptySlice = make([]interface{}, 0) + fullSlice = []interface{}{nil} + ) + runTestCases(t, false, "", []interface{}{ + sliceContainer{v: nilSlice}, + sliceContainer{v: emptySlice}, + sliceContainer{v: fullSlice}, + slicePtrContainer{v: nil}, + slicePtrContainer{v: &nilSlice}, + slicePtrContainer{v: &emptySlice}, + slicePtrContainer{v: &fullSlice}, + }) +} diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go new file mode 100644 index 000000000..40869cdfb --- /dev/null +++ b/pkg/state/tests/bench.go @@ -0,0 +1,24 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type benchStruct struct { + B *benchStruct // Must be exported for gob. +} + +func (b *benchStruct) afterLoad() { + // Do nothing, just force scheduling. +} diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go new file mode 100644 index 000000000..7e102c907 --- /dev/null +++ b/pkg/state/tests/bench_test.go @@ -0,0 +1,153 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" +) + +// buildPtrObject builds a benchmark object. +func buildPtrObject(n int) interface{} { + b := new(benchStruct) + for i := 0; i < n; i++ { + b = &benchStruct{B: b} + } + return b +} + +// buildMapObject builds a benchmark object. +func buildMapObject(n int) interface{} { + b := new(benchStruct) + m := make(map[int]*benchStruct) + for i := 0; i < n; i++ { + m[i] = b + } + return &m +} + +// buildSliceObject builds a benchmark object. +func buildSliceObject(n int) interface{} { + b := new(benchStruct) + s := make([]*benchStruct, 0, n) + for i := 0; i < n; i++ { + s = append(s, b) + } + return &s +} + +var allObjects = map[string]struct { + New func(int) interface{} +}{ + "ptr": { + New: buildPtrObject, + }, + "map": { + New: buildMapObject, + }, + "slice": { + New: buildSliceObject, + }, +} + +func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) { + // maxSize is the maximum size of an individual object below. For an N + // larger than this, we start to return multiple objects. + const maxSize = 1024 + if n <= maxSize { + return 1, fn(n) + } + iters = (n + maxSize - 1) / maxSize + return iters, fn(maxSize) +} + +// gobSave is a version of save using gob (no stats available). +func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) { + enc := gob.NewEncoder(w) + err = enc.Encode(v) + return +} + +// gobLoad is a version of load using gob (no stats available). 
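+//
+// Both helpers drive gob through the same wire.Writer/wire.Reader interfaces
+// used by the state encoder (gob itself only needs the embedded io.Reader and
+// io.Writer), so the benchmarks below compare like with like. A sketch of a
+// manual round-trip using these helpers (obj is a placeholder):
+//
+//	var buf bytes.Buffer
+//	if _, err := gobSave(context.Background(), &buf, obj); err != nil {
+//		// Handle the error.
+//	}
+//	if _, err := gobLoad(context.Background(), bytes.NewReader(buf.Bytes()), obj); err != nil {
+//		// Handle the error.
+//	}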
+func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) { + dec := gob.NewDecoder(r) + err = dec.Decode(v) + return +} + +var allAlgos = map[string]struct { + Save func(context.Context, wire.Writer, interface{}) (state.Stats, error) + Load func(context.Context, wire.Reader, interface{}) (state.Stats, error) + MaxPtr int +}{ + "state": { + Save: state.Save, + Load: state.Load, + }, + "gob": { + Save: gobSave, + Load: gobLoad, + }, +} + +func BenchmarkEncoding(b *testing.B) { + for objName, objInfo := range allObjects { + for algoName, algoInfo := range allAlgos { + b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { + b.StopTimer() + n, v := buildObjects(b.N, objInfo.New) + b.ReportAllocs() + b.StartTimer() + for i := 0; i < n; i++ { + if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil { + b.Errorf("save failed: %v", err) + } + } + b.StopTimer() + }) + } + } +} + +func BenchmarkDecoding(b *testing.B) { + for objName, objInfo := range allObjects { + for algoName, algoInfo := range allAlgos { + b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { + b.StopTimer() + n, v := buildObjects(b.N, objInfo.New) + buf := new(bytes.Buffer) + if _, err := algoInfo.Save(context.Background(), buf, v); err != nil { + b.Errorf("save failed: %v", err) + } + b.ReportAllocs() + b.StartTimer() + var r bytes.Reader + for i := 0; i < n; i++ { + r.Reset(buf.Bytes()) + if _, err := algoInfo.Load(context.Background(), &r, v); err != nil { + b.Errorf("load failed: %v", err) + } + } + b.StopTimer() + }) + } + } +} diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go new file mode 100644 index 000000000..e17cfacf9 --- /dev/null +++ b/pkg/state/tests/bool_test.go @@ -0,0 +1,31 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +var allBools = []bool{ + true, + false, +} + +func TestBool(t *testing.T) { + runTestCases(t, false, "plain", flatten(allBools)) + runTestCases(t, false, "pointers", pointersTo(flatten(allBools))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools)))) +} diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go new file mode 100644 index 000000000..3e89edd9c --- /dev/null +++ b/pkg/state/tests/float_test.go @@ -0,0 +1,118 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+	"math"
+	"testing"
+)
+
+var safeFloat32s = []float32{
+	float32(0.0),
+	float32(1.0),
+	float32(-1.0),
+	float32(math.Inf(1)),
+	float32(math.Inf(-1)),
+}
+
+var allFloat32s = append(safeFloat32s, float32(math.NaN()))
+
+var safeFloat64s = []float64{
+	float64(0.0),
+	float64(1.0),
+	float64(-1.0),
+	math.Inf(1),
+	math.Inf(-1),
+}
+
+var allFloat64s = append(safeFloat64s, math.NaN())
+
+func TestFloat(t *testing.T) {
+	runTestCases(t, false, "plain", flatten(
+		allFloat32s,
+		allFloat64s,
+	))
+	// See checkEqual for why NaNs are missing.
+	runTestCases(t, false, "pointers", pointersTo(flatten(
+		safeFloat32s,
+		safeFloat64s,
+	)))
+	runTestCases(t, false, "interfaces", interfacesTo(flatten(
+		safeFloat32s,
+		safeFloat64s,
+	)))
+	runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+		safeFloat32s,
+		safeFloat64s,
+	))))
+}
+
+const onlyDouble float64 = 1.0000000000000002
+
+func TestFloatTruncation(t *testing.T) {
+	runTestCases(t, true, "pass", []interface{}{
+		truncatingFloat32{save: onlyDouble},
+	})
+	runTestCases(t, false, "fail", []interface{}{
+		truncatingFloat32{save: 1.0},
+	})
+}
+
+var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
+	return complex(i.(float32), j.(float32))
+})
+
+var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
+	return complex(i.(float32), j.(float32))
+})
+
+var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
+	return complex(i.(float64), j.(float64))
+})
+
+var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
+	return complex(i.(float64), j.(float64))
+})
+
+func TestComplex(t *testing.T) {
+	runTestCases(t, false, "plain", flatten(
+		allComplex64s,
+		allComplex128s,
+	))
+	// See TestFloat; same issue.
+	runTestCases(t, false, "pointers", pointersTo(flatten(
+		safeComplex64s,
+		safeComplex128s,
+	)))
+	runTestCases(t, false, "interfaces", interfacesTo(flatten(
+		safeComplex64s,
+		safeComplex128s,
+	)))
+	runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+		safeComplex64s,
+		safeComplex128s,
+	))))
+}
+
+func TestComplexTruncation(t *testing.T) {
+	runTestCases(t, true, "pass", []interface{}{
+		truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
+		truncatingComplex64{save: complex(1.0, onlyDouble)},
+		truncatingComplex64{save: complex(onlyDouble, 1.0)},
+	})
+	runTestCases(t, false, "fail", []interface{}{
+		truncatingComplex64{save: complex(1.0, 1.0)},
+	})
+}
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
new file mode 100644
index 000000000..ca403eed1
--- /dev/null
+++ b/pkg/state/tests/integer.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
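+// The truncating helpers in this file are driven by the truncation tests
+// above (for example, TestFloatTruncation). The probe value used there,
+// 1.0000000000000002, is 1 + 2^-52: the smallest float64 strictly greater
+// than 1, which has no exact float32 representation. A standalone sketch of
+// the detection (hypothetical test, not part of this change):
+//
+//	func TestFloat32RoundTrip(t *testing.T) {
+//		const onlyDouble float64 = 1.0000000000000002
+//		if float64(float32(onlyDouble)) == onlyDouble {
+//			t.Error("float32 round-trip unexpectedly preserved the value")
+//		}
+//		if float64(float32(1.0)) != 1.0 {
+//			t.Error("float32 round-trip should preserve 1.0")
+//		}
+//	}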
+ +package tests + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +// +stateify type +type truncatingUint8 struct { + save uint64 + load uint8 `state:"nosave"` +} + +func (t *truncatingUint8) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint8) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint8)(nil) + +// +stateify type +type truncatingUint16 struct { + save uint64 + load uint16 `state:"nosave"` +} + +func (t *truncatingUint16) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint16) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint16)(nil) + +// +stateify type +type truncatingUint32 struct { + save uint64 + load uint32 `state:"nosave"` +} + +func (t *truncatingUint32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint32)(nil) + +// +stateify type +type truncatingInt8 struct { + save int64 + load int8 `state:"nosave"` +} + +func (t *truncatingInt8) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt8) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt8)(nil) + +// +stateify type +type truncatingInt16 struct { + save int64 + load int16 `state:"nosave"` +} + +func (t *truncatingInt16) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt16) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt16)(nil) + +// +stateify type +type truncatingInt32 struct { + save int64 + load int32 `state:"nosave"` +} + +func (t *truncatingInt32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt32)(nil) + +// +stateify type +type truncatingFloat32 struct { + save float64 + load float32 `state:"nosave"` +} + +func (t *truncatingFloat32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingFloat32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = float64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingFloat32)(nil) + +// +stateify type +type truncatingComplex64 struct { + save complex128 + load complex64 `state:"nosave"` +} + +func (t *truncatingComplex64) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingComplex64) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = complex128(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingComplex64)(nil) diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go new file mode 100644 index 000000000..d3931c952 --- /dev/null +++ b/pkg/state/tests/integer_test.go @@ -0,0 +1,94 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "math" + "testing" +) + +var ( + allIntTs = []int{-1, 0, 1} + allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} + allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} + allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} + allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} + allUintTs = []uint{0, 1} + allUintptrs = []uintptr{0, 1, ^uintptr(0)} + allUint8s = []uint8{0, 1, math.MaxUint8} + allUint16s = []uint16{0, 1, math.MaxUint16} + allUint32s = []uint32{0, 1, math.MaxUint32} + allUint64s = []uint64{0, 1, math.MaxUint64} +) + +var allInts = flatten( + allIntTs, + allInt8s, + allInt16s, + allInt32s, + allInt64s, +) + +var allUints = flatten( + allUintTs, + allUintptrs, + allUint8s, + allUint16s, + allUint32s, + allUint64s, +) + +func TestInt(t *testing.T) { + runTestCases(t, false, "plain", allInts) + runTestCases(t, false, "pointers", pointersTo(allInts)) + runTestCases(t, false, "interfaces", interfacesTo(allInts)) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts))) +} + +func TestIntTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingInt8{save: math.MinInt8 - 1}, + truncatingInt16{save: math.MinInt16 - 1}, + truncatingInt32{save: math.MinInt32 - 1}, + truncatingInt8{save: math.MaxInt8 + 1}, + truncatingInt16{save: math.MaxInt16 + 1}, + truncatingInt32{save: math.MaxInt32 + 1}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingInt8{save: 1}, + truncatingInt16{save: 1}, + truncatingInt32{save: 1}, + }) +} + +func TestUint(t *testing.T) { + runTestCases(t, false, "plain", allUints) + runTestCases(t, false, "pointers", pointersTo(allUints)) + runTestCases(t, false, "interfaces", interfacesTo(allUints)) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints))) +} + +func TestUintTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingUint8{save: math.MaxUint8 + 1}, + truncatingUint16{save: math.MaxUint16 + 1}, + truncatingUint32{save: math.MaxUint32 + 1}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingUint8{save: 1}, + truncatingUint16{save: 1}, + truncatingUint32{save: 1}, + }) +} diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go new file mode 100644 index 000000000..a8350c0f3 --- /dev/null +++ b/pkg/state/tests/load.go @@ -0,0 +1,61 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
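+// The integer truncation tests above (TestIntTruncation, TestUintTruncation)
+// depend on Go's wrapping conversions: narrowing a value that does not fit
+// changes it, which the truncating helpers then detect on load. A standalone
+// sketch (hypothetical test, not part of this change):
+//
+//	func TestUint8Wraps(t *testing.T) {
+//		saved := uint64(math.MaxUint8 + 1) // 256
+//		if loaded := uint8(saved); uint64(loaded) == saved {
+//			t.Error("expected the narrowing conversion to wrap to 0")
+//		}
+//	}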
+
+package tests
+
+// +stateify savable
+type genericContainer struct {
+	v interface{}
+}
+
+// +stateify savable
+type afterLoadStruct struct {
+	v int `state:"nosave"`
+}
+
+func (a *afterLoadStruct) afterLoad() {
+	a.v++
+}
+
+// +stateify savable
+type valueLoadStruct struct {
+	v int `state:".(int64)"`
+}
+
+func (v *valueLoadStruct) saveV() int64 {
+	return int64(v.v) // Save as int64.
+}
+
+func (v *valueLoadStruct) loadV(value int64) {
+	v.v = int(value) // Load as int.
+}
+
+// +stateify savable
+type cycleStruct struct {
+	c *cycleStruct
+}
+
+// +stateify savable
+type badCycleStruct struct {
+	b *badCycleStruct `state:"wait"`
+}
+
+func (b *badCycleStruct) afterLoad() {
+	if b.b != b {
+		// This should never execute: afterLoad runs only once the
+		// object and all dependencies are complete, so a bad cycle
+		// instead causes a deadlock error during load.
+		panic("badCycleStruct.afterLoad called")
+	}
+}
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
new file mode 100644
index 000000000..1e9794296
--- /dev/null
+++ b/pkg/state/tests/load_test.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+	"testing"
+)
+
+func TestLoadHooks(t *testing.T) {
+	runTestCases(t, false, "load-hooks", []interface{}{
+		&afterLoadStruct{v: 1},
+		&valueLoadStruct{v: 1},
+		&genericContainer{v: &afterLoadStruct{v: 1}},
+		&genericContainer{v: &valueLoadStruct{v: 1}},
+		&sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+		&sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+		&mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+		&mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+	})
+}
+
+func TestCycles(t *testing.T) {
+	// cs is a single object cycle.
+	cs := cycleStruct{nil}
+	cs.c = &cs
+
+	// cs1 and cs2 are in a two object cycle.
+	cs1 := cycleStruct{nil}
+	cs2 := cycleStruct{nil}
+	cs1.c = &cs2
+	cs2.c = &cs1
+
+	runTestCases(t, false, "cycles", []interface{}{
+		cs,
+		cs1,
+	})
+}
+
+func TestDeadlock(t *testing.T) {
+	// bs is a single object cycle. This does not cause deadlock because an
+	// object cannot wait for itself.
+	bs := badCycleStruct{nil}
+	bs.b = &bs
+
+	runTestCases(t, false, "self", []interface{}{
+		&bs,
+	})
+
+	// bs1 and bs2 are in a deadlocking cycle.
+	bs1 := badCycleStruct{nil}
+	bs2 := badCycleStruct{nil}
+	bs1.b = &bs2
+	bs2.b = &bs1
+
+	runTestCases(t, true, "deadlock", []interface{}{
+		&bs1,
+	})
+}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
new file mode 100644
index 000000000..db4e548f1
--- /dev/null
+++ b/pkg/state/tests/map.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type mapContainer struct { + v map[int]interface{} +} + +// +stateify savable +type mapPtrContainer struct { + v *map[int]interface{} +} + +// +stateify savable +type registeredMapStruct struct{} diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go new file mode 100644 index 000000000..92bf0fc01 --- /dev/null +++ b/pkg/state/tests/map_test.go @@ -0,0 +1,90 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "reflect" + "testing" +) + +var allMapPrimitives = []interface{}{ + bool(true), + int(1), + int8(1), + int16(1), + int32(1), + int64(1), + uint(1), + uintptr(1), + uint8(1), + uint16(1), + uint32(1), + uint64(1), + string(""), + registeredMapStruct{}, +} + +var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives)) + +var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives)) + +var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { + m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) + return m.Interface() +}) + +var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { + m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) + m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2))) + return m.Interface() +}) + +func TestMapAliasing(t *testing.T) { + v := make(map[int]int) + ptrToV := &v + aliases := []map[int]int{v, v} + runTestCases(t, false, "", []interface{}{ptrToV, aliases}) +} + +func TestMapsEmpty(t *testing.T) { + runTestCases(t, false, "plain", emptyMaps) + runTestCases(t, false, "pointers", pointersTo(emptyMaps)) + runTestCases(t, false, "interfaces", interfacesTo(emptyMaps)) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps))) +} + +func TestMapsFull(t *testing.T) { + runTestCases(t, false, "plain", fullMaps) + runTestCases(t, false, "pointers", pointersTo(fullMaps)) + runTestCases(t, false, "interfaces", interfacesTo(fullMaps)) + runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps))) +} + +func TestMapContainers(t *testing.T) { + var ( + nilMap map[int]interface{} + emptyMap = make(map[int]interface{}) + fullMap = map[int]interface{}{0: nil} + ) + runTestCases(t, false, "", []interface{}{ + mapContainer{v: nilMap}, + mapContainer{v: emptyMap}, + mapContainer{v: fullMap}, + mapPtrContainer{v: nil}, + mapPtrContainer{v: &nilMap}, + mapPtrContainer{v: &emptyMap}, + 
mapPtrContainer{v: &fullMap}, + }) +} diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go new file mode 100644 index 000000000..074d86315 --- /dev/null +++ b/pkg/state/tests/register.go @@ -0,0 +1,21 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type alreadyRegisteredStruct struct{} + +// +stateify savable +type alreadyRegisteredOther int diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go new file mode 100644 index 000000000..c829753cc --- /dev/null +++ b/pkg/state/tests/register_test.go @@ -0,0 +1,167 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/state" +) + +// faker calls itself whatever is in the name field. +type faker struct { + Name string + Fields []string +} + +func (f *faker) StateTypeName() string { + return f.Name +} + +func (f *faker) StateFields() []string { + return f.Fields +} + +// fakerWithSaverLoader has all it needs. +type fakerWithSaverLoader struct { + faker +} + +func (f *fakerWithSaverLoader) StateSave(m state.Sink) {} + +func (f *fakerWithSaverLoader) StateLoad(m state.Source) {} + +// fakerOther calls itself .. uh, itself? +type fakerOther string + +func (f *fakerOther) StateTypeName() string { + return string(*f) +} + +func (f *fakerOther) StateFields() []string { + return nil +} + +func newFakerOther(name string) *fakerOther { + f := fakerOther(name) + return &f +} + +// fakerOtherBadFields returns non-nil fields. +type fakerOtherBadFields string + +func (f *fakerOtherBadFields) StateTypeName() string { + return string(*f) +} + +func (f *fakerOtherBadFields) StateFields() []string { + return []string{string(*f)} +} + +func newFakerOtherBadFields(name string) *fakerOtherBadFields { + f := fakerOtherBadFields(name) + return &f +} + +// fakerOtherSaverLoader implements SaverLoader methods. 
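+// It exists to verify that Register rejects non-struct types that implement
+// SaverLoader (the "non-struct-with-saverloader" case in TestRegisterBad).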
+type fakerOtherSaverLoader string + +func (f *fakerOtherSaverLoader) StateTypeName() string { + return string(*f) +} + +func (f *fakerOtherSaverLoader) StateFields() []string { + return nil +} + +func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {} + +func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {} + +func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader { + f := fakerOtherSaverLoader(name) + return &f +} + +func TestRegisterPrimitives(t *testing.T) { + for _, typeName := range []string{ + "int", + "int8", + "int16", + "int32", + "int64", + "uint", + "uintptr", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + "string", + } { + t.Run("struct/"+typeName, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering type %q did not panic", typeName) + } + }() + state.Register(&faker{ + Name: typeName, + }) + }) + t.Run("other/"+typeName, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering type %q did not panic", typeName) + } + }() + state.Register(newFakerOther(typeName)) + }) + } +} + +func TestRegisterBad(t *testing.T) { + const ( + goodName = "foo" + firstField = "a" + secondField = "b" + ) + for name, object := range map[string]state.Type{ + "non-struct-with-fields": newFakerOtherBadFields(goodName), + "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName), + "struct-without-saverloader": &faker{Name: goodName}, + "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()), + "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()), + "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}}, + "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}}, + "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}}, + "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}}, + "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}}, + "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}}, + } { + t.Run(name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering object %#v did not panic", object) + } + }() + state.Register(object) + }) + + } +} diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go new file mode 100644 index 000000000..44f5a562c --- /dev/null +++ b/pkg/state/tests/string_test.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
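+// For contrast with the panic cases in register_test.go above, a sketch of a
+// registration that Register accepts: a pointer to a struct implementing both
+// Type and SaverLoader. The names here are hypothetical; in practice the
+// stateify tool generates this boilerplate from the +stateify annotations.
+//
+//	type example struct{ v int64 }
+//
+//	func (e *example) StateTypeName() string   { return "tests.example" }
+//	func (e *example) StateFields() []string   { return []string{"v"} }
+//	func (e *example) StateSave(m state.Sink)  { m.Save(0, &e.v) }
+//	func (e *example) StateLoad(m state.Source) { m.Load(0, &e.v) }
+//
+//	func init() { state.Register((*example)(nil)) }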
+ +package tests + +import ( + "testing" +) + +const nonEmptyString = "hello world" + +var allStrings = []string{ + "", + nonEmptyString, + "\\0", +} + +func TestString(t *testing.T) { + runTestCases(t, false, "plain", flatten(allStrings)) + runTestCases(t, false, "pointers", pointersTo(flatten(allStrings))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings)))) +} diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go new file mode 100644 index 000000000..bd2c2b399 --- /dev/null +++ b/pkg/state/tests/struct.go @@ -0,0 +1,65 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type unregisteredEmptyStruct struct{} + +// typeOnlyEmptyStruct just implements the state.Type interface. +type typeOnlyEmptyStruct struct{} + +func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" } + +func (*typeOnlyEmptyStruct) StateFields() []string { return nil } + +// +stateify savable +type savableEmptyStruct struct{} + +// +stateify savable +type emptyStructPointer struct { + nothing *struct{} +} + +// +stateify savable +type outerSame struct { + inner inner +} + +// +stateify savable +type outerFieldFirst struct { + inner inner + v int64 +} + +// +stateify savable +type outerFieldSecond struct { + v int64 + inner inner +} + +// +stateify savable +type outerArray struct { + inner [2]inner +} + +// +stateify savable +type inner struct { + v int64 +} + +// +stateify savable +type system struct { + v1 interface{} + v2 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go new file mode 100644 index 000000000..de9d17aa7 --- /dev/null +++ b/pkg/state/tests/struct_test.go @@ -0,0 +1,89 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/state" +) + +func TestEmptyStruct(t *testing.T) { + runTestCases(t, false, "plain", []interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + savableEmptyStruct{}, + }) + runTestCases(t, false, "pointers", pointersTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + savableEmptyStruct{}, + })) + runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{ + // Only registered types can be dispatched via interfaces. 
+		// All other types should fail, even for the empty struct.
+		savableEmptyStruct{},
+	}))
+	runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
+		unregisteredEmptyStruct{},
+		typeOnlyEmptyStruct{},
+	}))
+	runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
+		savableEmptyStruct{},
+	})))
+	runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
+		unregisteredEmptyStruct{},
+		typeOnlyEmptyStruct{},
+	})))
+
+	// Ensure that empty struct aliasing works.
+	es := emptyStructPointer{new(struct{})}
+	runTestCases(t, false, "empty-struct-pointers", []interface{}{
+		emptyStructPointer{},
+		es,
+		[]emptyStructPointer{es, es}, // Same pointer.
+	})
+}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+	defer func() {
+		if r := recover(); r == nil {
+			t.Errorf("Register did not panic")
+		}
+	}()
+	state.Register((*typeOnlyEmptyStruct)(nil))
+}
+
+func TestEmbeddedPointers(t *testing.T) {
+	var (
+		ofs outerSame
+		of1 outerFieldFirst
+		of2 outerFieldSecond
+		oa  outerArray
+	)
+
+	runTestCases(t, false, "embedded-pointers", []interface{}{
+		system{&ofs, &ofs.inner},
+		system{&ofs.inner, &ofs},
+		system{&of1, &of1.inner},
+		system{&of1.inner, &of1},
+		system{&of2, &of2.inner},
+		system{&of2.inner, &of2},
+		system{&oa, &oa.inner[0]},
+		system{&oa, &oa.inner[1]},
+		system{&oa.inner[0], &oa},
+		system{&oa.inner[1], &oa},
+	})
+}
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
new file mode 100644
index 000000000..435a0e9db
--- /dev/null
+++ b/pkg/state/tests/tests.go
@@ -0,0 +1,215 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tests tests the state packages.
+package tests
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"math"
+	"reflect"
+	"testing"
+
+	"gvisor.dev/gvisor/pkg/state"
+	"gvisor.dev/gvisor/pkg/state/pretty"
+)
+
+// discard is an implementation of wire.Writer.
+type discard struct{}
+
+// Write implements wire.Writer.Write.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte implements wire.Writer.WriteByte.
+func (discard) WriteByte(byte) error { return nil }
+
+// checkEqual checks if two objects are equal.
+//
+// N.B. This only handles NaN at the top level of dereference; otherwise we
+// would need to fork the entire implementation of reflect.DeepEqual.
+func checkEqual(root, loadedValue interface{}) bool {
+	if reflect.DeepEqual(root, loadedValue) {
+		return true
+	}
+
+	// NaN is not equal to itself. We handle the case of raw floating point
+	// primitives here, but don't handle this case nested.
+	rf32, ok1 := root.(float32)
+	lf32, ok2 := loadedValue.(float32)
+	if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
+		return true
+	}
+	rf64, ok1 := root.(float64)
+	lf64, ok2 := loadedValue.(float64)
+	if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
+		return true
+	}
+
+	// The same applies to complex numbers: compare component-wise, so that
+	// NaN parts are handled as above.
+	rc64, ok1 := root.(complex64)
+	lc64, ok2 := loadedValue.(complex64)
+	if ok1 && ok2 {
+		return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
+	}
+	rc128, ok1 := root.(complex128)
+	lc128, ok2 := loadedValue.(complex128)
+	if ok1 && ok2 {
+		return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
+	}
+
+	return false
+}
+
+// runTestCases runs a test for each object in objects.
+func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
+	t.Helper()
+	for i, root := range objects {
+		t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
+			t.Logf("Original object:\n%#v", root)
+
+			// Save the passed object.
+			saveBuffer := &bytes.Buffer{}
+			saveObjectPtr := reflect.New(reflect.TypeOf(root))
+			saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+			saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
+			if err != nil {
+				if shouldFail {
+					return
+				}
+				t.Fatalf("Save failed unexpectedly: %v", err)
+			}
+
+			// Dump the serialized state to aid with debugging.
+			var ppBuf bytes.Buffer
+			t.Logf("Raw state:\n%v", saveBuffer.Bytes())
+			if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+				// We don't count this as a test failure if we
+				// have shouldFail set, but we do count it as a
+				// failure if we were not expecting to fail.
+				if !shouldFail {
+					t.Errorf("PrettyPrint(html=false) failed unexpectedly: %v", err)
+				}
+			}
+			if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+				// See above.
+				if !shouldFail {
+					t.Errorf("PrettyPrint(html=true) failed unexpectedly: %v", err)
+				}
+			}
+			t.Logf("Encoded state:\n%s", ppBuf.String())
+			t.Logf("Save stats:\n%s", saveStats.String())
+
+			// Load a new copy of the object.
+			loadObjectPtr := reflect.New(reflect.TypeOf(root))
+			loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
+			if err != nil {
+				if shouldFail {
+					return
+				}
+				t.Fatalf("Load failed unexpectedly: %v", err)
+			}
+
+			// Compare the values.
+			loadedValue := loadObjectPtr.Elem().Interface()
+			if !checkEqual(root, loadedValue) {
+				if shouldFail {
+					return
+				}
+				t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
+			}
+
+			// Everything went okay; fail if we expected failure.
+			if shouldFail {
+				t.Fatalf("This test was expected to fail, but didn't.")
+			}
+			t.Logf("Load stats:\n%s", loadStats.String())
+
+			// Truncate half the bytes in the byte stream,
+			// and ensure that we can't restore. Then
+			// truncate only the final byte and ensure that
+			// we can't restore.
+			l := saveBuffer.Len()
+			halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
+			if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
+				t.Errorf("Load with half bytes succeeded unexpectedly.")
+			}
+			missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
+			if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
+				t.Errorf("Load with missing byte succeeded unexpectedly.")
+			}
+		})
+	}
+}
+
+// convert converts the slice to an []interface{}.
+func convert(v interface{}) (r []interface{}) {
+	s := reflect.ValueOf(v) // Must be slice.
+	for i := 0; i < s.Len(); i++ {
+		r = append(r, s.Index(i).Interface())
+	}
+	return r
+}
+
+// flatten flattens multiple slices.
+func flatten(vs ...interface{}) (r []interface{}) {
+	for _, v := range vs {
+		r = append(r, convert(v)...)
+ } + return r +} + +// filter maps from one slice to another. +func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) { + s := reflect.ValueOf(vs) + for i := 0; i < s.Len(); i++ { + v, ok := fn(s.Index(i).Interface()) + if ok { + r = append(r, v) + } + } + return r +} + +// combine combines objects in two slices as specified. +func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) { + s1 := reflect.ValueOf(v1) + s2 := reflect.ValueOf(v2) + for i := 0; i < s1.Len(); i++ { + for j := 0; j < s2.Len(); j++ { + // Combine using the given function. + r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface())) + } + } + return r +} + +// pointersTo is a filter function that returns pointers. +func pointersTo(vs interface{}) []interface{} { + return filter(vs, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)) + v.Elem().Set(reflect.ValueOf(o)) + return v.Interface(), true + }) +} + +// interfacesTo is a filter function that returns interface objects. +func interfacesTo(vs interface{}) []interface{} { + return filter(vs, func(o interface{}) (interface{}, bool) { + var v [1]interface{} + v[0] = o + return v, true + }) +} diff --git a/pkg/state/types.go b/pkg/state/types.go new file mode 100644 index 000000000..215ef80f8 --- /dev/null +++ b/pkg/state/types.go @@ -0,0 +1,361 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "sort" + + "gvisor.dev/gvisor/pkg/state/wire" +) + +// assertValidType asserts that the type is valid. +func assertValidType(name string, fields []string) { + if name == "" { + Failf("type has empty name") + } + fieldsCopy := make([]string, len(fields)) + for i := 0; i < len(fields); i++ { + if fields[i] == "" { + Failf("field has empty name for type %q", name) + } + fieldsCopy[i] = fields[i] + } + sort.Slice(fieldsCopy, func(i, j int) bool { + return fieldsCopy[i] < fieldsCopy[j] + }) + for i := range fieldsCopy { + if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] { + Failf("duplicate field %q for type %s", fieldsCopy[i], name) + } + } +} + +// typeEntry is an entry in the typeDatabase. +type typeEntry struct { + ID typeID + wire.Type +} + +// reconciledTypeEntry is a reconciled entry in the typeDatabase. +type reconciledTypeEntry struct { + wire.Type + LocalType reflect.Type + FieldOrder []int +} + +// typeEncodeDatabase is an internal TypeInfo database for encoding. +type typeEncodeDatabase struct { + // byType maps by type to the typeEntry. + byType map[reflect.Type]*typeEntry + + // lastID is the last used ID. + lastID typeID +} + +// makeTypeEncodeDatabase makes a typeDatabase. +func makeTypeEncodeDatabase() typeEncodeDatabase { + return typeEncodeDatabase{ + byType: make(map[reflect.Type]*typeEntry), + } +} + +// typeDecodeDatabase is an internal TypeInfo database for decoding. +type typeDecodeDatabase struct { + // byID maps by ID to type. 
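+	// It is indexed by typeID-1, since valid IDs start at 1; see Lookup.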
+	byID []*reconciledTypeEntry
+
+	// pending are entries that are pending validation by Lookup. These
+	// will be reconciled with actual objects. Note that these will also be
+	// used to look up types by name, since they may not be reconciled and
+	// there's little value in deleting from this list.
+	pending []*wire.Type
+}
+
+// makeTypeDecodeDatabase makes a typeDecodeDatabase.
+func makeTypeDecodeDatabase() typeDecodeDatabase {
+	return typeDecodeDatabase{}
+}
+
+// lookupNameFields extracts the name and fields from an object.
+func lookupNameFields(typ reflect.Type) (string, []string, bool) {
+	v := reflect.Zero(reflect.PtrTo(typ)).Interface()
+	t, ok := v.(Type)
+	if !ok {
+		// Is this a primitive?
+		if typ.Kind() == reflect.Interface {
+			return interfaceType, nil, true
+		}
+		name := typ.Name()
+		if _, ok := primitiveTypeDatabase[name]; !ok {
+			// This is not a known type, and not a primitive. The
+			// encoder may proceed for anonymous empty structs, or
+			// it may dereference the type pointer and try again.
+			return "", nil, false
+		}
+		return name, nil, true
+	}
+	// Extract the name from the object.
+	name := t.StateTypeName()
+	fields := t.StateFields()
+	assertValidType(name, fields)
+	return name, fields, true
+}
+
+// Lookup looks up or registers the given type.
+//
+// The bool indicates whether this is an existing entry: false means the entry
+// did not exist, and true means the entry did exist. If this bool is false and
+// the returned typeEntry is nil, then the type does not implement the Type
+// interface.
+func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
+	te, ok := tdb.byType[typ]
+	if !ok {
+		// Lookup the type information.
+		name, fields, ok := lookupNameFields(typ)
+		if !ok {
+			// Empty structs may still be encoded, so let the
+			// caller decide what to do from here.
+			return nil, false
+		}
+
+		// Register the new type.
+		tdb.lastID++
+		te = &typeEntry{
+			ID: tdb.lastID,
+			Type: wire.Type{
+				Name:   name,
+				Fields: fields,
+			},
+		}
+
+		// All done.
+		tdb.byType[typ] = te
+		return te, false
+	}
+	return te, true
+}
+
+// Register adds a typeID entry.
+func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
+	assertValidType(typ.Name, typ.Fields)
+	tbd.pending = append(tbd.pending, typ)
+}
+
+// LookupName looks up the type name by ID.
+func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
+	if len(tbd.pending) < int(id) {
+		// This is likely an encoder error.
+		Failf("type ID %d not available", id)
+	}
+	return tbd.pending[id-1].Name
+}
+
+// LookupType looks up the type by ID.
+func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
+	name := tbd.LookupName(id)
+	typ, ok := globalTypeDatabase[name]
+	if !ok {
+		// If not available, see if it's primitive.
+		typ, ok = primitiveTypeDatabase[name]
+		if !ok && name == interfaceType {
+			// Matches the built-in interface type.
+			var i interface{}
+			return reflect.TypeOf(&i).Elem()
+		}
+		if !ok {
+			// The type is probably not registered.
+			Failf("type name %q is not available", name)
+		}
+		return typ // Primitive type.
+	}
+	return typ // Registered type.
}
+
+// singleFieldOrder defines the field order for a single field.
+var singleFieldOrder = []int{0}
+
+// Lookup looks up or registers the given type.
+//
+// First, the typeID is searched to see if it has already been appropriately
+// reconciled. If not, then a reconciliation takes place, which may produce a
+// field ordering for the type's fields. Any failure during reconciliation is
+// fatal (via Failf).
+//
+// This method never returns nil.
+func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
+	if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
+		// Already reconciled.
+		return tbd.byID[id-1]
+	}
+	// The ID has not been reconciled yet. That's fine. We need to make
+	// sure it aligns with the current provided object.
+	if len(tbd.pending) < int(id) {
+		// This ID was never registered; probably an encoder error.
+		Failf("typeDatabase does not contain id %d", id)
+	}
+	// Extract the pending info.
+	pending := tbd.pending[id-1]
+	// Grow the byID list.
+	if len(tbd.byID) < int(id) {
+		tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
+	}
+	// Reconcile the type.
+	name, fields, ok := lookupNameFields(typ)
+	if !ok {
+		// Empty structs are decoded only when the type is nil. Since
+		// this isn't the case, we fail here.
+		Failf("unsupported type %q during decode; can't reconcile", pending.Name)
+	}
+	if name != pending.Name {
+		// Are these the same type? Print a helpful message as this may
+		// actually happen in practice if types change.
+		Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
+			id, name, fields, pending.Name, pending.Fields)
+	}
+	rte := &reconciledTypeEntry{
+		Type: wire.Type{
+			Name:   name,
+			Fields: fields,
+		},
+		LocalType: typ,
+	}
+	// If there are zero or one fields, then we skip allocating the field
+	// slice. There is special handling for decoding in this case. If the
+	// field name does not match, it will be caught in the general purpose
+	// code below.
+	if len(fields) != len(pending.Fields) {
+		Failf("type %q contains different fields: %v (decode) and %v (encode)",
+			name, fields, pending.Fields)
+	}
+	if len(fields) == 0 {
+		tbd.byID[id-1] = rte // Save.
+		return rte
+	}
+	if len(fields) == 1 && fields[0] == pending.Fields[0] {
+		tbd.byID[id-1] = rte // Save.
+		rte.FieldOrder = singleFieldOrder
+		return rte
+	}
+	// For each field in the current object's information, match it to a
+	// field in the destination object. We know from the assertion above
+	// and the assertion made on insertion into pending that neither
+	// field list contains any duplicates.
+	fieldOrder := make([]int, len(fields))
+	for i, name := range fields {
+		fieldOrder[i] = -1 // Sentinel.
+		// Is it an exact match?
+		if pending.Fields[i] == name {
+			fieldOrder[i] = i
+			continue
+		}
+		// Find the matching field.
+		for j, otherName := range pending.Fields {
+			if name == otherName {
+				fieldOrder[i] = j
+				break
+			}
+		}
+		if fieldOrder[i] == -1 {
+			// The type name matches but we are lacking some common fields.
+			Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
+				name, fields, pending.Fields)
+		}
+	}
+	// The type has been reconciled.
+	rte.FieldOrder = fieldOrder
+	tbd.byID[id-1] = rte
+	return rte
+}
+
+// interfaceType defines all interfaces.
+const interfaceType = "interface"
+
+// primitiveTypeDatabase is a set of fixed types.
+var primitiveTypeDatabase = func() map[string]reflect.Type {
+	r := make(map[string]reflect.Type)
+	for _, t := range []reflect.Type{
+		reflect.TypeOf(false),
+		reflect.TypeOf(int(0)),
+		reflect.TypeOf(int8(0)),
+		reflect.TypeOf(int16(0)),
+		reflect.TypeOf(int32(0)),
+		reflect.TypeOf(int64(0)),
+		reflect.TypeOf(uint(0)),
+		reflect.TypeOf(uintptr(0)),
+		reflect.TypeOf(uint8(0)),
+		reflect.TypeOf(uint16(0)),
+		reflect.TypeOf(uint32(0)),
+		reflect.TypeOf(uint64(0)),
+		reflect.TypeOf(""),
+		reflect.TypeOf(float32(0.0)),
+		reflect.TypeOf(float64(0.0)),
+		reflect.TypeOf(complex64(0.0)),
+		reflect.TypeOf(complex128(0.0)),
+	} {
+		r[t.Name()] = t
+	}
+	return r
+}()
+
+// globalTypeDatabase is used for dispatching interfaces on decode.
+var globalTypeDatabase = map[string]reflect.Type{}
+
+// Register registers a type.
+//
+// This must be called on init and only done once.
+func Register(t Type) {
+	name := t.StateTypeName()
+	fields := t.StateFields()
+	assertValidType(name, fields)
+	// Register must always be called on pointers.
+	typ := reflect.TypeOf(t)
+	if typ.Kind() != reflect.Ptr {
+		Failf("Register must be called on pointers")
+	}
+	typ = typ.Elem()
+	if typ.Kind() == reflect.Struct {
+		// All registered structs must implement SaverLoader. We allow
+		// the registration of non-struct types with just the Type
+		// interface, but we need to call StateSave/StateLoad methods
+		// on aggregate types.
+		if _, ok := t.(SaverLoader); !ok {
+			Failf("struct %T does not implement SaverLoader", t)
+		}
+	} else {
+		// Non-structs must not have any fields. We don't support
+		// calling StateSave/StateLoad methods on any non-struct types.
+		// If custom behavior is required, these types should be
+		// wrapped in a structure of some kind.
+		if len(fields) != 0 {
+			Failf("non-struct %T has non-zero fields %v", t, fields)
+		}
+		// We don't allow non-structs to implement StateSave/StateLoad
+		// methods, because they won't be called and it's confusing.
+		if _, ok := t.(SaverLoader); ok {
+			Failf("non-struct %T implements SaverLoader", t)
+		}
+	}
+	if _, ok := primitiveTypeDatabase[name]; ok {
+		Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
+	}
+	if _, ok := globalTypeDatabase[name]; ok {
+		Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+	}
+	if name == interfaceType {
+		Failf("conflicting name for %T: matches interfaceType", t)
+	}
+	globalTypeDatabase[name] = typ
+}
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
new file mode 100644
index 000000000..311b93dcb
--- /dev/null
+++ b/pkg/state/wire/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+    name = "wire",
+    srcs = ["wire.go"],
+    marshal = False,
+    stateify = False,
+    visibility = ["//:sandbox"],
+    deps = ["//pkg/gohacks"],
+)
diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go
new file mode 100644
index 000000000..93dee6740
--- /dev/null
+++ b/pkg/state/wire/wire.go
@@ -0,0 +1,970 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package wire contains a few basic types that can be composed to serialize
+// graph information for the state package. This package defines the wire
+// protocol.
+//
+// Note that these types are careful about how they implement the relevant
+// interfaces (either value receiver or pointer receiver), so that native-sized
+// types, such as integers and simple pointers, can fit inside the interface
+// object.
+//
+// This package also uses panic as control flow, so callers should be careful
+// to wrap calls in appropriate handlers.
+//
+// Testing for this package is driven by the state test package.
+package wire
+
+import (
+	"fmt"
+	"io"
+	"math"
+
+	"gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Reader is the required reader interface.
+type Reader interface {
+	io.Reader
+	ReadByte() (byte, error)
+}
+
+// Writer is the required writer interface.
+type Writer interface {
+	io.Writer
+	WriteByte(byte) error
+}
+
+// readFull is a utility. The equivalent is not needed for Write, because the
+// io.Writer contract dictates that Write must always complete all bytes given
+// or return an error.
+func readFull(r io.Reader, p []byte) {
+	for done := 0; done < len(p); {
+		n, err := r.Read(p[done:])
+		done += n
+		if n == 0 && err != nil {
+			panic(err)
+		}
+	}
+}
+
+// Object is a generic object.
+type Object interface {
+	// save saves the given object.
+	//
+	// Panic is used for error control flow.
+	save(Writer)
+
+	// load loads a new object of the given type.
+	//
+	// Panic is used for error control flow.
+	load(Reader) Object
+}
+
+// Bool is a boolean.
+type Bool bool
+
+// loadBool loads an object of type Bool.
+func loadBool(r Reader) Bool {
+	b := loadUint(r)
+	return Bool(b == 1)
+}
+
+// save implements Object.save.
+func (b Bool) save(w Writer) {
+	var v Uint
+	if b {
+		v = 1
+	} else {
+		v = 0
+	}
+	v.save(w)
+}
+
+// load implements Object.load.
+func (Bool) load(r Reader) Object { return loadBool(r) }
+
+// Int is a signed integer.
+//
+// This uses varint encoding.
+type Int int64
+
+// loadInt loads an object of type Int.
+func loadInt(r Reader) Int {
+	u := loadUint(r)
+	x := Int(u >> 1)
+	if u&1 != 0 {
+		x = ^x
+	}
+	return x
+}
+
+// save implements Object.save.
+func (i Int) save(w Writer) {
+	u := Uint(i) << 1
+	if i < 0 {
+		u = ^u
+	}
+	u.save(w)
+}
+
+// load implements Object.load.
+func (Int) load(r Reader) Object { return loadInt(r) }
+
+// Uint is an unsigned integer.
+type Uint uint64
+
+// loadUint loads an object of type Uint.
+func loadUint(r Reader) Uint {
+	var (
+		u Uint
+		s uint
+	)
+	for i := 0; i <= 9; i++ {
+		b, err := r.ReadByte()
+		if err != nil {
+			panic(err)
+		}
+		if b < 0x80 {
+			if i == 9 && b > 1 {
+				panic("overflow")
+			}
+			u |= Uint(b) << s
+			return u
+		}
+		u |= Uint(b&0x7f) << s
+		s += 7
+	}
+	panic("unreachable")
+}
+
+// save implements Object.save.
+func (u Uint) save(w Writer) {
+	for u >= 0x80 {
+		if err := w.WriteByte(byte(u) | 0x80); err != nil {
+			panic(err)
+		}
+		u >>= 7
+	}
+	if err := w.WriteByte(byte(u)); err != nil {
+		panic(err)
+	}
+}
+
+// load implements Object.load.
+func (Uint) load(r Reader) Object { return loadUint(r) }
+
+// Float32 is a 32-bit floating point number.
+type Float32 float32
+
+// loadFloat32 loads an object of type Float32.
+func loadFloat32(r Reader) Float32 {
+	n := loadUint(r)
+	return Float32(math.Float32frombits(uint32(n)))
+}
+
+// save implements Object.save.
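+// The value is written as its IEEE 754 bit pattern, varint-encoded as a Uint.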
+func (f Float32) save(w Writer) {
+	n := Uint(math.Float32bits(float32(f)))
+	n.save(w)
+}
+
+// load implements Object.load.
+func (Float32) load(r Reader) Object { return loadFloat32(r) }
+
+// Float64 is a 64-bit floating point number.
+type Float64 float64
+
+// loadFloat64 loads an object of type Float64.
+func loadFloat64(r Reader) Float64 {
+	n := loadUint(r)
+	return Float64(math.Float64frombits(uint64(n)))
+}
+
+// save implements Object.save.
+func (f Float64) save(w Writer) {
+	n := Uint(math.Float64bits(float64(f)))
+	n.save(w)
+}
+
+// load implements Object.load.
+func (Float64) load(r Reader) Object { return loadFloat64(r) }
+
+// Complex64 is a 64-bit complex number.
+type Complex64 complex128
+
+// loadComplex64 loads an object of type Complex64.
+func loadComplex64(r Reader) Complex64 {
+	re := loadFloat32(r)
+	im := loadFloat32(r)
+	return Complex64(complex(float32(re), float32(im)))
+}
+
+// save implements Object.save.
+func (c *Complex64) save(w Writer) {
+	re := Float32(real(*c))
+	im := Float32(imag(*c))
+	re.save(w)
+	im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex64) load(r Reader) Object {
+	c := loadComplex64(r)
+	return &c
+}
+
+// Complex128 is a 128-bit complex number.
+type Complex128 complex128
+
+// loadComplex128 loads an object of type Complex128.
+func loadComplex128(r Reader) Complex128 {
+	re := loadFloat64(r)
+	im := loadFloat64(r)
+	return Complex128(complex(float64(re), float64(im)))
+}
+
+// save implements Object.save.
+func (c *Complex128) save(w Writer) {
+	re := Float64(real(*c))
+	im := Float64(imag(*c))
+	re.save(w)
+	im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex128) load(r Reader) Object {
+	c := loadComplex128(r)
+	return &c
+}
+
+// String is a string.
+type String string
+
+// loadString loads an object of type String.
+func loadString(r Reader) String {
+	l := loadUint(r)
+	p := make([]byte, l)
+	readFull(r, p)
+	return String(gohacks.StringFromImmutableBytes(p))
+}
+
+// save implements Object.save.
+func (s *String) save(w Writer) {
+	l := Uint(len(*s))
+	l.save(w)
+	p := gohacks.ImmutableBytesFromString(string(*s))
+	_, err := w.Write(p) // Must write all bytes.
+	if err != nil {
+		panic(err)
+	}
+}
+
+// load implements Object.load.
+func (*String) load(r Reader) Object {
+	s := loadString(r)
+	return &s
+}
+
+// Dot is a kind of reference: one of Index and FieldName.
+type Dot interface {
+	isDot()
+}
+
+// Index is a reference resolution.
+type Index uint32
+
+func (Index) isDot() {}
+
+// FieldName is a reference resolution.
+type FieldName string
+
+func (*FieldName) isDot() {}
+
+// Ref is a reference to an object.
+type Ref struct {
+	// Root is the root object.
+	Root Uint
+
+	// Dots is the set of traversals required from the Root object above.
+	// Note that this will be stored in reverse order for efficiency.
+	Dots []Dot
+
+	// Type is the base type for the root object. This is non-nil iff Dots
+	// is non-zero length (that is, this is a complex reference). This is
+	// not *strictly* necessary, but can be used to simplify decoding.
+	Type TypeSpec
+}
+
+// loadRef loads an object of type Ref (abstract).
+func loadRef(r Reader) Ref {
+	ref := Ref{
+		Root: loadUint(r),
+	}
+	l := loadUint(r)
+	ref.Dots = make([]Dot, l)
+	for i := 0; i < int(l); i++ {
+		// Disambiguate between an Index (non-negative) and a field
+		// name (negative). This saves some space and avoids a
+		// dedicated loadDot function. See Ref.save for the other side.
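+		// For example, Index(3) is stored as the signed varint 3,
+		// while FieldName("foo") is stored as -3 followed by the
+		// three raw bytes of the name.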
+ d := loadInt(r) + if d >= 0 { + ref.Dots[i] = Index(d) + continue + } + p := make([]byte, -d) + readFull(r, p) + fieldName := FieldName(gohacks.StringFromImmutableBytes(p)) + ref.Dots[i] = &fieldName + } + if l != 0 { + // Only if dots is non-zero. + ref.Type = loadTypeSpec(r) + } + return ref +} + +// save implements Object.save. +func (r *Ref) save(w Writer) { + r.Root.save(w) + l := Uint(len(r.Dots)) + l.save(w) + for _, d := range r.Dots { + // See LoadRef. We use non-negative numbers to encode Index + // objects and negative numbers to encode field lengths. + switch x := d.(type) { + case Index: + i := Int(x) + i.save(w) + case *FieldName: + d := Int(-len(*x)) + d.save(w) + p := gohacks.ImmutableBytesFromString(string(*x)) + if _, err := w.Write(p); err != nil { + panic(err) + } + default: + panic("unknown dot implementation") + } + } + if l != 0 { + // See above. + saveTypeSpec(w, r.Type) + } +} + +// load implements Object.load. +func (*Ref) load(r Reader) Object { + ref := loadRef(r) + return &ref +} + +// Nil is a primitive zero value of any type. +type Nil struct{} + +// loadNil loads an object of type Nil. +func loadNil(r Reader) Nil { + return Nil{} +} + +// save implements Object.save. +func (Nil) save(w Writer) {} + +// load implements Object.load. +func (Nil) load(r Reader) Object { return loadNil(r) } + +// Slice is a slice value. +type Slice struct { + Length Uint + Capacity Uint + Ref Ref +} + +// loadSlice loads an object of type Slice. +func loadSlice(r Reader) Slice { + return Slice{ + Length: loadUint(r), + Capacity: loadUint(r), + Ref: loadRef(r), + } +} + +// save implements Object.save. +func (s *Slice) save(w Writer) { + s.Length.save(w) + s.Capacity.save(w) + s.Ref.save(w) +} + +// load implements Object.load. +func (*Slice) load(r Reader) Object { + s := loadSlice(r) + return &s +} + +// Array is an array value. +type Array struct { + Contents []Object +} + +// loadArray loads an object of type Array. +func loadArray(r Reader) Array { + l := loadUint(r) + if l == 0 { + // Note that there isn't a single object available to encode + // the type of, so we need this additional branch. + return Array{} + } + // All the objects here have the same type, so use dynamic dispatch + // only once. All other objects will automatically take the same type + // as the first object. + contents := make([]Object, l) + v := Load(r) + contents[0] = v + for i := 1; i < int(l); i++ { + contents[i] = v.load(r) + } + return Array{ + Contents: contents, + } +} + +// save implements Object.save. +func (a *Array) save(w Writer) { + l := Uint(len(a.Contents)) + l.save(w) + if l == 0 { + // See LoadArray. + return + } + // See above. + Save(w, a.Contents[0]) + for i := 1; i < int(l); i++ { + a.Contents[i].save(w) + } +} + +// load implements Object.load. +func (*Array) load(r Reader) Object { + a := loadArray(r) + return &a +} + +// Map is a map value. +type Map struct { + Keys []Object + Values []Object +} + +// loadMap loads an object of type Map. +func loadMap(r Reader) Map { + l := loadUint(r) + if l == 0 { + // See LoadArray. + return Map{} + } + // See type dispatch notes in Array. + keys := make([]Object, l) + values := make([]Object, l) + k := Load(r) + v := Load(r) + keys[0] = k + values[0] = v + for i := 1; i < int(l); i++ { + keys[i] = k.load(r) + values[i] = v.load(r) + } + return Map{ + Keys: keys, + Values: values, + } +} + +// save implements Object.save. 
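+// The length is written once; the first key and value are saved with full
+// type headers, and the remaining pairs reuse those types (see loadMap).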
+func (m *Map) save(w Writer) {
+	l := Uint(len(m.Keys))
+	if int(l) != len(m.Values) {
+		panic(fmt.Sprintf("mismatched keys (%d) and values (%d)", len(m.Keys), len(m.Values)))
+	}
+	l.save(w)
+	if l == 0 {
+		// See LoadArray.
+		return
+	}
+	// See above.
+	Save(w, m.Keys[0])
+	Save(w, m.Values[0])
+	for i := 1; i < int(l); i++ {
+		m.Keys[i].save(w)
+		m.Values[i].save(w)
+	}
+}
+
+// load implements Object.load.
+func (*Map) load(r Reader) Object {
+	m := loadMap(r)
+	return &m
+}
+
+// TypeSpec is a type dereference.
+type TypeSpec interface {
+	isTypeSpec()
+}
+
+// TypeID is a concrete type ID.
+type TypeID Uint
+
+func (TypeID) isTypeSpec() {}
+
+// TypeSpecPointer is a pointer type.
+type TypeSpecPointer struct {
+	Type TypeSpec
+}
+
+func (*TypeSpecPointer) isTypeSpec() {}
+
+// TypeSpecArray is an array type.
+type TypeSpecArray struct {
+	Count Uint
+	Type  TypeSpec
+}
+
+func (*TypeSpecArray) isTypeSpec() {}
+
+// TypeSpecSlice is a slice type.
+type TypeSpecSlice struct {
+	Type TypeSpec
+}
+
+func (*TypeSpecSlice) isTypeSpec() {}
+
+// TypeSpecMap is a map type.
+type TypeSpecMap struct {
+	Key   TypeSpec
+	Value TypeSpec
+}
+
+func (*TypeSpecMap) isTypeSpec() {}
+
+// TypeSpecNil is an empty type.
+type TypeSpecNil struct{}
+
+func (TypeSpecNil) isTypeSpec() {}
+
+// TypeSpec types.
+//
+// These use a distinct encoding on the wire, as they are used only in the
+// interface object. They are decoded through the dedicated loadTypeSpec and
+// saveTypeSpec functions.
+const (
+	typeSpecTypeID Uint = iota
+	typeSpecPointer
+	typeSpecArray
+	typeSpecSlice
+	typeSpecMap
+	typeSpecNil
+)
+
+// loadTypeSpec loads TypeSpec values.
+func loadTypeSpec(r Reader) TypeSpec {
+	switch hdr := loadUint(r); hdr {
+	case typeSpecTypeID:
+		return TypeID(loadUint(r))
+	case typeSpecPointer:
+		return &TypeSpecPointer{
+			Type: loadTypeSpec(r),
+		}
+	case typeSpecArray:
+		return &TypeSpecArray{
+			Count: loadUint(r),
+			Type:  loadTypeSpec(r),
+		}
+	case typeSpecSlice:
+		return &TypeSpecSlice{
+			Type: loadTypeSpec(r),
+		}
+	case typeSpecMap:
+		return &TypeSpecMap{
+			Key:   loadTypeSpec(r),
+			Value: loadTypeSpec(r),
+		}
+	case typeSpecNil:
+		return TypeSpecNil{}
+	default:
+		// This is not a valid stream?
+		panic(fmt.Errorf("unknown header: %d", hdr))
+	}
+}
+
+// saveTypeSpec saves TypeSpec values.
+func saveTypeSpec(w Writer, t TypeSpec) {
+	switch x := t.(type) {
+	case TypeID:
+		typeSpecTypeID.save(w)
+		Uint(x).save(w)
+	case *TypeSpecPointer:
+		typeSpecPointer.save(w)
+		saveTypeSpec(w, x.Type)
+	case *TypeSpecArray:
+		typeSpecArray.save(w)
+		x.Count.save(w)
+		saveTypeSpec(w, x.Type)
+	case *TypeSpecSlice:
+		typeSpecSlice.save(w)
+		saveTypeSpec(w, x.Type)
+	case *TypeSpecMap:
+		typeSpecMap.save(w)
+		saveTypeSpec(w, x.Key)
+		saveTypeSpec(w, x.Value)
+	case TypeSpecNil:
+		typeSpecNil.save(w)
+	default:
+		// This should not happen?
+		panic(fmt.Errorf("unknown type %T", t))
+	}
+}
+
+// Interface is an interface value.
+type Interface struct {
+	Type  TypeSpec
+	Value Object
+}
+
+// loadInterface loads an object of type Interface.
+func loadInterface(r Reader) Interface {
+	return Interface{
+		Type:  loadTypeSpec(r),
+		Value: Load(r),
+	}
+}
+
+// save implements Object.save.
+func (i *Interface) save(w Writer) {
+	saveTypeSpec(w, i.Type)
+	Save(w, i.Value)
+}
+
+// load implements Object.load.
+func (*Interface) load(r Reader) Object {
+	i := loadInterface(r)
+	return &i
+}
+
+// Type is type information.
+type Type struct {
+	Name   string
+	Fields []string
+}
+
+// loadType loads an object of type Type.
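+// The wire format is the type name as a String, followed by a Uint field
+// count and each field name as a String.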
+func loadType(r Reader) Type {
+	name := string(loadString(r))
+	l := loadUint(r)
+	fields := make([]string, l)
+	for i := 0; i < int(l); i++ {
+		fields[i] = string(loadString(r))
+	}
+	return Type{
+		Name:   name,
+		Fields: fields,
+	}
+}
+
+// save implements Object.save.
+func (t *Type) save(w Writer) {
+	s := String(t.Name)
+	s.save(w)
+	l := Uint(len(t.Fields))
+	l.save(w)
+	for i := 0; i < int(l); i++ {
+		s := String(t.Fields[i])
+		s.save(w)
+	}
+}
+
+// load implements Object.load.
+func (*Type) load(r Reader) Object {
+	t := loadType(r)
+	return &t
+}
+
+// multipleObjects is a special type for serializing multiple objects.
+type multipleObjects []Object
+
+// loadMultipleObjects loads a series of objects.
+func loadMultipleObjects(r Reader) multipleObjects {
+	l := loadUint(r)
+	m := make(multipleObjects, l)
+	for i := 0; i < int(l); i++ {
+		m[i] = Load(r)
+	}
+	return m
+}
+
+// save implements Object.save.
+func (m *multipleObjects) save(w Writer) {
+	l := Uint(len(*m))
+	l.save(w)
+	for i := 0; i < int(l); i++ {
+		Save(w, (*m)[i])
+	}
+}
+
+// load implements Object.load.
+func (*multipleObjects) load(r Reader) Object {
+	m := loadMultipleObjects(r)
+	return &m
+}
+
+// noObjects represents no objects.
+type noObjects struct{}
+
+// loadNoObjects loads a sentinel.
+func loadNoObjects(r Reader) noObjects { return noObjects{} }
+
+// save implements Object.save.
+func (noObjects) save(w Writer) {}
+
+// load implements Object.load.
+func (noObjects) load(r Reader) Object { return loadNoObjects(r) }
+
+// Struct is a basic composite value.
+type Struct struct {
+	TypeID TypeID
+	fields Object // Optionally noObjects or *multipleObjects.
+}
+
+// Field returns a pointer to the given field slot.
+//
+// This must be called after Alloc.
+func (s *Struct) Field(i int) *Object {
+	if fields, ok := s.fields.(*multipleObjects); ok {
+		return &((*fields)[i])
+	}
+	if _, ok := s.fields.(noObjects); ok {
+		// Alloc was called with zero slots; there are no field slots
+		// to return.
+		panic("Field called inappropriately, wrong Alloc?")
+	}
+	return &s.fields
+}
+
+// Alloc allocates the given number of fields.
+//
+// This must be called before Add and Save.
+//
+// Precondition: slots must be non-negative.
+func (s *Struct) Alloc(slots int) {
+	switch {
+	case slots == 0:
+		s.fields = noObjects{}
+	case slots == 1:
+		// Leave it alone.
+	case slots > 1:
+		fields := make(multipleObjects, slots)
+		s.fields = &fields
+	default:
+		// Violates precondition.
+		panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
+	}
+}
+
+// Fields returns the number of fields.
+func (s *Struct) Fields() int {
+	switch x := s.fields.(type) {
+	case *multipleObjects:
+		return len(*x)
+	case noObjects:
+		return 0
+	default:
+		return 1
+	}
+}
+
+// loadStruct loads an object of type Struct.
+func loadStruct(r Reader) Struct {
+	return Struct{
+		TypeID: TypeID(loadUint(r)),
+		fields: Load(r),
+	}
+}
+
+// save implements Object.save.
+//
+// Precondition: Alloc must have been called, and the fields all filled in
+// appropriately. See Alloc and Add for more details.
+func (s *Struct) save(w Writer) {
+	Uint(s.TypeID).save(w)
+	Save(w, s.fields)
+}
+
+// load implements Object.load.
+func (*Struct) load(r Reader) Object {
+	s := loadStruct(r)
+	return &s
+}
+
+// Object types.
+//
+// N.B. Be careful about changing the order or introducing new elements in the
+// middle here. This is part of the wire format and shouldn't change.
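+// New object types must only be appended to the end of this list.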
+const ( + typeBool Uint = iota + typeInt + typeUint + typeFloat32 + typeFloat64 + typeNil + typeRef + typeString + typeSlice + typeArray + typeMap + typeStruct + typeNoObjects + typeMultipleObjects + typeInterface + typeComplex64 + typeComplex128 + typeType +) + +// Save saves the given object. +// +// +checkescape all +// +// N.B. This function will panic on error. +func Save(w Writer, obj Object) { + switch x := obj.(type) { + case Bool: + typeBool.save(w) + x.save(w) + case Int: + typeInt.save(w) + x.save(w) + case Uint: + typeUint.save(w) + x.save(w) + case Float32: + typeFloat32.save(w) + x.save(w) + case Float64: + typeFloat64.save(w) + x.save(w) + case Nil: + typeNil.save(w) + x.save(w) + case *Ref: + typeRef.save(w) + x.save(w) + case *String: + typeString.save(w) + x.save(w) + case *Slice: + typeSlice.save(w) + x.save(w) + case *Array: + typeArray.save(w) + x.save(w) + case *Map: + typeMap.save(w) + x.save(w) + case *Struct: + typeStruct.save(w) + x.save(w) + case noObjects: + typeNoObjects.save(w) + x.save(w) + case *multipleObjects: + typeMultipleObjects.save(w) + x.save(w) + case *Interface: + typeInterface.save(w) + x.save(w) + case *Type: + typeType.save(w) + x.save(w) + case *Complex64: + typeComplex64.save(w) + x.save(w) + case *Complex128: + typeComplex128.save(w) + x.save(w) + default: + panic(fmt.Errorf("unknown type: %#v", obj)) + } +} + +// Load loads a new object. +// +// +checkescape all +// +// N.B. This function will panic on error. +func Load(r Reader) Object { + switch hdr := loadUint(r); hdr { + case typeBool: + return loadBool(r) + case typeInt: + return loadInt(r) + case typeUint: + return loadUint(r) + case typeFloat32: + return loadFloat32(r) + case typeFloat64: + return loadFloat64(r) + case typeNil: + return loadNil(r) + case typeRef: + return ((*Ref)(nil)).load(r) // Escapes. + case typeString: + return ((*String)(nil)).load(r) // Escapes. + case typeSlice: + return ((*Slice)(nil)).load(r) // Escapes. + case typeArray: + return ((*Array)(nil)).load(r) // Escapes. + case typeMap: + return ((*Map)(nil)).load(r) // Escapes. + case typeStruct: + return ((*Struct)(nil)).load(r) // Escapes. + case typeNoObjects: // Special for struct. + return loadNoObjects(r) + case typeMultipleObjects: // Special for struct. + return ((*multipleObjects)(nil)).load(r) // Escapes. + case typeInterface: + return ((*Interface)(nil)).load(r) // Escapes. + case typeComplex64: + return ((*Complex64)(nil)).load(r) // Escapes. + case typeComplex128: + return ((*Complex128)(nil)).load(r) // Escapes. + case typeType: + return ((*Type)(nil)).load(r) // Escapes. + default: + // This is not a valid stream? + panic(fmt.Errorf("unknown header: %d", hdr)) + } +} + +// LoadUint loads a single unsigned integer. +// +// N.B. This function will panic on error. +func LoadUint(r Reader) uint64 { + return uint64(loadUint(r)) +} + +// SaveUint saves a single unsigned integer. +// +// N.B. This function will panic on error. +func SaveUint(w Writer, v uint64) { + Uint(v).save(w) +} diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index d0d77e19c..4d47207f7 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -33,6 +33,7 @@ go_library( "aliases.go", "memmove_unsafe.go", "mutex_unsafe.go", + "nocopy.go", "norace_unsafe.go", "race_unsafe.go", "rwmutex_unsafe.go", diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go new file mode 100644 index 000000000..722b29501 --- /dev/null +++ b/pkg/sync/nocopy.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +// NoCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type NoCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Unlock() {} diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 8ff922c69..5ae10939d 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -22,7 +22,7 @@ import ( // Mapping for tcpip.Error types. var ( ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL) + ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c1745ba6a..ee264b726 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -320,6 +320,22 @@ func DstPort(port uint16) TransportChecker { } } +// NoChecksum creates a checker that checks if the checksum is zero. +func NoChecksum(noChecksum bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + udp, ok := h.(header.UDP) + if !ok { + return + } + + if b := udp.Checksum() == 0; b != noChecksum { + t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) + } + } +} + // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0cde694dc..d87797617 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -48,7 +48,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -64,6 +64,6 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. 
+type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. 
// It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index b1e92d2d7..eaface8cb 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -53,6 +53,10 @@ const ( // (all bits set to 0). unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + // EthernetBroadcastAddress is an ethernet address that addresses every node + // on a local link. + EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") + // unicastMulticastFlagMask is the mask of the least significant bit in // the first octet (in network byte order) of an ethernet address that // determines whether the ethernet address is a unicast or multicast. If diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 7908c5744..1a631b31a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -72,6 +72,7 @@ const ( // Values for ICMP code as defined in RFC 792. const ( ICMPv4TTLExceeded = 0 + ICMPv4HostUnreachable = 1 ICMPv4PortUnreachable = 3 ICMPv4FragmentationNeeded = 4 ) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index c7ee2de57..a13b4b809 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -110,9 +110,16 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP code as defined in RFC 4443. +// Values for ICMP destination unreachable code as defined in RFC 4443 section +// 3.1. const ( - ICMPv6PortUnreachable = 4 + ICMPv6NetworkUnreachable = 0 + ICMPv6Prohibited = 1 + ICMPv6BeyondScope = 2 + ICMPv6AddressUnreachable = 3 + ICMPv6PortUnreachable = 4 + ICMPv6Policy = 5 + ICMPv6RejectRoute = 6 ) // Type is the ICMP type field. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index b8b93e78e..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 20b183da0..e12a5929b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -296,3 +297,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { e.q.RemoveNotify(handle) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. 
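+// It is a no-op; the channel endpoint does not prepend a link-layer header.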
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index aa6db9aea..507b44abc 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/binary", + "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f34082e1a..c18bb91fb 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,6 +45,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -385,26 +386,35 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -430,47 +440,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + builder.Add(vnetHdrBuf) } - if pkt.Data.Size() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Header.View()) - } - if pkt.Header.UsedLength() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + builder.Add(pkt.Header.View()) + for _, v := range pkt.Data.Views() { + builder.Add(v) } - return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) + return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { - var ethHdrBuf []byte - iovLen := 0 if e.hdrSize > 0 { - // Add ethernet header if needed. 
- ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - } - - // Preserve the src address if it's set in the route. - if pkt.EgressRoute.LocalLinkAddress != "" { - ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } - vnetHdr := virtioNetHdr{} var vnetHdrBuf []byte if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if pkt.GSOOptions.NeedsCsum { @@ -491,45 +482,19 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc } } vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - iovLen++ } - iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + var builder iovec.Builder + builder.Add(vnetHdrBuf) + builder.Add(pkt.Header.View()) + for _, v := range pkt.Data.Views() { + builder.Add(v) + } + iovecs := builder.Build() + var mmsgHdr rawfile.MMsgHdr mmsgHdr.Msg.Iov = &iovecs[0] - iovecIdx := 0 - if vnetHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = &vnetHdrBuf[0] - v.Len = uint64(len(vnetHdrBuf)) - iovecIdx++ - } - if ethHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = ðHdrBuf[0] - v.Len = uint64(len(ethHdrBuf)) - iovecIdx++ - } - pktSize := uint64(0) - // Encode L3 Header - v := &iovecs[iovecIdx] - hdr := &pkt.Header - hdrView := hdr.View() - v.Base = &hdrView[0] - v.Len = uint64(len(hdrView)) - pktSize += v.Len - iovecIdx++ - - // Now encode the Transport Payload. - pktViews := pkt.Data.Views() - for i := range pktViews { - vec := &iovecs[iovecIdx] - iovecIdx++ - vec.Base = &pktViews[i][0] - vec.Len = uint64(len(pktViews[i])) - pktSize += vec.Len - } - mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + mmsgHdr.Msg.Iovlen = uint64(len(iovecs)) mmsgHdrs = append(mmsgHdrs, mmsgHdr) } @@ -626,6 +591,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index eaee7e5d7..7b995b85a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin c.ch <- packetInfo{remote, protocol, pkt} } +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNoEthernetProperties(t *testing.T) { c := newContext(t, &Options{MTU: mtu}) defer c.cleanup() @@ -500,3 +504,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) { } } + +// fakeNetworkDispatcher delivers packets to pkts. 
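+// It appends every delivered packet to pkts so the test can inspect them.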
+type fakeNetworkDispatcher struct { + pkts []*stack.PacketBuffer +} + +func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + d.pkts = append(d.pkts, pkt) +} + +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestDispatchPacketFormat(t *testing.T) { + for _, test := range []struct { + name string + newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) + }{ + { + name: "readVDispatcher", + newDispatcher: newReadVDispatcher, + }, + { + name: "recvMMsgDispatcher", + newDispatcher: newRecvMMsgDispatcher, + }, + } { + t.Run(test.name, func(t *testing.T) { + // Create a socket pair to send/recv. + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Fatal(err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + data := []byte{ + // Ethernet header. + 1, 2, 3, 4, 5, 60, + 1, 2, 3, 4, 5, 61, + 8, 0, + // Mock network header. + 40, 41, 42, 43, + } + err = syscall.Sendmsg(fds[1], data, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + + // Create and run dispatcher once. + sink := &fakeNetworkDispatcher{} + d, err := test.newDispatcher(fds[0], &endpoint{ + hdrSize: header.EthernetMinimumSize, + dispatcher: sink, + }) + if err != nil { + t.Fatal(err) + } + if ok, err := d.dispatch(); !ok || err != nil { + t.Fatalf("d.dispatch() = %v, %v", ok, err) + } + + // Verify packet. + if got, want := len(sink.pkts), 1; got != want { + t.Fatalf("len(sink.pkts) = %d, want %d", got, want) + } + pkt := sink.pkts[0] + if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { + t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + } + if got, want := pkt.Data.Size(), 4; got != want { + t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index f04738cfb..d8f2504b3 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -278,7 +278,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 568c6874f..781cdd317 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -113,3 +113,11 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. 
+func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 82b441b79..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c69d6b7e9..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD new file mode 100644 index 000000000..2cdb23475 --- /dev/null +++ b/pkg/tcpip/link/nested/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "nested", + srcs = [ + "nested.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "nested_test", + size = "small", + srcs = [ + "nested_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go new file mode 100644 index 000000000..d40de54df --- /dev/null +++ b/pkg/tcpip/link/nested/nested.go @@ -0,0 +1,152 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nested provides helpers to implement the pattern of nested +// stack.LinkEndpoints. +package nested + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher +// that can be used to implement nesting safely by providing lifecycle +// concurrency guards. 
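+//
+// A wrapper embeds Endpoint and passes itself to Init as the embedder; the
+// child then delivers inbound packets to the wrapper, which forwards them to
+// the attached dispatcher.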
+// +// See the tests in this package for example usage. +type Endpoint struct { + child stack.LinkEndpoint + embedder stack.NetworkDispatcher + + // mu protects dispatcher. + mu sync.RWMutex + dispatcher stack.NetworkDispatcher +} + +var _ stack.GSOEndpoint = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.NetworkDispatcher = (*Endpoint)(nil) + +// Init initializes a nested.Endpoint that uses embedder as the dispatcher for +// child on Attach. +// +// See the tests in this package for example usage. +func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) { + e.child = child + e.embedder = embedder +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverNetworkPacket(remote, local, protocol, pkt) + } +} + +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + e.dispatcher = dispatcher + e.mu.Unlock() + // If we're attaching to a valid dispatcher, pass embedder as the dispatcher + // to our child, otherwise detach the child by giving it a nil dispatcher. + var pass stack.NetworkDispatcher + if dispatcher != nil { + pass = e.embedder + } + e.child.Attach(pass) +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + e.mu.RLock() + isAttached := e.dispatcher != nil + e.mu.RUnlock() + return isAttached +} + +// MTU implements stack.LinkEndpoint. +func (e *Endpoint) MTU() uint32 { + return e.child.MTU() +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.child.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.child.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.child.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + return e.child.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + return e.child.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return e.child.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint. +func (e *Endpoint) Wait() { + e.child.Wait() +} + +// GSOMaxSize implements stack.GSOEndpoint. 
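+// It forwards to the child if the child is a stack.GSOEndpoint, and returns
+// 0 otherwise.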
+func (e *Endpoint) GSOMaxSize() uint32 { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.GSOMaxSize() + } + return 0 +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go new file mode 100644 index 000000000..7d9249c1c --- /dev/null +++ b/pkg/tcpip/link/nested/nested_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nested_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type parentEndpoint struct { + nested.Endpoint +} + +var _ stack.LinkEndpoint = (*parentEndpoint)(nil) +var _ stack.NetworkDispatcher = (*parentEndpoint)(nil) + +type childEndpoint struct { + stack.LinkEndpoint + dispatcher stack.NetworkDispatcher +} + +var _ stack.LinkEndpoint = (*childEndpoint)(nil) + +func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + c.dispatcher = dispatcher +} + +func (c *childEndpoint) IsAttached() bool { + return c.dispatcher != nil +} + +type counterDispatcher struct { + count int +} + +var _ stack.NetworkDispatcher = (*counterDispatcher)(nil) + +func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + d.count++ +} + +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestNestedLinkEndpoint(t *testing.T) { + const emptyAddress = tcpip.LinkAddress("") + + var ( + childEP childEndpoint + nestedEP parentEndpoint + disp counterDispatcher + ) + nestedEP.Endpoint.Init(&childEP, &nestedEP) + + if childEP.IsAttached() { + t.Error("On init, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("On init, nestedEP.IsAttached() = true, want = false") + } + + nestedEP.Attach(&disp) + if disp.count != 0 { + t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count) + } + if !childEP.IsAttached() { + t.Error("After attach, childEP.IsAttached() = false, want = true") + } + if !nestedEP.IsAttached() { + t.Error("After attach, nestedEP.IsAttached() = false, want = true") + } + + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + if disp.count != 1 { + t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) + } + + nestedEP.Attach(nil) + if childEP.IsAttached() { + t.Error("After 
detach, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("After detach, nestedEP.IsAttached() = true, want = false") + } + + disp.count = 0 + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + if disp.count != 0 { + t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) + } + +} diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. 
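+// Each packet is delivered to interested packet sockets via
+// DeliverOutboundPacket before being written out through the lower endpoint.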
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 054c213bc..1d0079bd6 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b5dfb7850..467083239 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher @@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } @@ -207,3 +215,13 @@ func (e *endpoint) Wait() { e.wg.Wait() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..f4c32c2da 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { return nil } -// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a -// single syscall. It fails if partial data is written. -func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { - // If the is no second buffer, issue a regular write. - if len(b2) == 0 { - return NonBlockingWrite(fd, b1) - } - - // We have two buffers. Build the iovec that represents them and issue - // a writev syscall. 
- iovec := [3]syscall.Iovec{ - { - Base: &b1[0], - Len: uint64(len(b1)), - }, - { - Base: &b2[0], - Len: uint64(len(b2)), - }, - } - iovecLen := uintptr(2) - - if len(b3) > 0 { - iovecLen++ - iovec[2].Base = &b3[0] - iovec[2].Len = uint64(len(b3)) - } - +// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. +// It fails if partial data is written. +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { + iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { return TranslateErrno(e) } - return nil } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0374a2441..507c76b76 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) v := pkt.Data.ToView() // Transmit the packet. 
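The rawfile change above generalizes the fixed three-buffer NonBlockingWrite3 into NonBlockingWriteIovec, which takes a caller-built iovec slice. Below is a minimal sketch of how a caller might assemble that slice, assuming linux/amd64 (where syscall.Iovec.Len is a uint64, matching the code above); the helper name buildIovec and the standalone program around it are illustrative, not part of the rawfile API.

package main

import "syscall"

// buildIovec packs non-empty byte slices into a []syscall.Iovec suitable for
// NonBlockingWriteIovec. Empty slices are skipped: that mirrors the old
// NonBlockingWrite3 behavior, and taking &b[0] would panic on an empty slice.
func buildIovec(bufs ...[]byte) []syscall.Iovec {
	iovec := make([]syscall.Iovec, 0, len(bufs))
	for _, b := range bufs {
		if len(b) == 0 {
			continue
		}
		iovec = append(iovec, syscall.Iovec{
			Base: &b[0],
			Len:  uint64(len(b)),
		})
	}
	return iovec
}

func main() {
	hdr := []byte{0x45, 0x00}
	payload := []byte{0xde, 0xad, 0xbe, 0xef}
	// A caller would then write both buffers in one writev syscall:
	//   rawfile.NonBlockingWriteIovec(fd, buildIovec(hdr, payload, nil))
	_ = buildIovec(hdr, payload, nil)
}
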
@@ -287,3 +294,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 28a2e88ba..8f3cd9449 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 230a8d53a..7cbc305e7 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index f2e47b6a7..509076643 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -48,18 +49,21 @@ var LogPackets uint32 = 1 var LogPacketsToPCAP uint32 = 1 type endpoint struct { - dispatcher stack.NetworkDispatcher - lower stack.LinkEndpoint + nested.Endpoint writer io.Writer maxPCAPLen uint32 } +var _ stack.GSOEndpoint = (*endpoint)(nil) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.NetworkDispatcher = (*endpoint)(nil) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets as they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - return &endpoint{ - lower: lower, - } + sniffer := &endpoint{} + sniffer.Endpoint.Init(lower, sniffer) + return sniffer } func zoneOffset() (int32, error) { @@ -103,11 +107,12 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( if err := writePCAPHeader(writer, snapLen); err != nil { return nil, err } - return &endpoint{ - lower: lower, + sniffer := &endpoint{ writer: writer, maxPCAPLen: snapLen, - }, nil + } + sniffer.Endpoint.Init(lower, sniffer) + return sniffer, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is @@ -115,50 +120,12 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dumpPacket("recv", nil, protocol, pkt) - e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) -} - -// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher -// and registers with the lower endpoint as its dispatcher so that "e" is called -// for inbound packets. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - e.lower.Attach(e) -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. 
-func (e *endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the -// lower endpoint. -func (e *endpoint) MTU() uint32 { - return e.lower.MTU() + e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } -// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the -// request to the lower endpoint. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.lower.Capabilities() -} - -// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards -// the request to the lower endpoint. -func (e *endpoint) MaxHeaderLength() uint16 { - return e.lower.MaxHeaderLength() -} - -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - return e.lower.LinkAddress() -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.lower.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { @@ -203,7 +170,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.dumpPacket("send", gso, protocol, pkt) - return e.lower.WritePacket(r, gso, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by @@ -213,7 +180,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket("send", gso, protocol, pkt) } - return e.lower.WritePackets(r, gso, pkts, protocol) + return e.Endpoint.WritePackets(r, gso, pkts, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. @@ -221,12 +188,9 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ Data: vv, }) - return e.lower.WriteRawPacket(vv) + return e.Endpoint.WriteRawPacket(vv) } -// Wait implements stack.LinkEndpoint.Wait. -func (e *endpoint) Wait() { e.lower.Wait() } - func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6bc9033d0..04ae58e59 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE stack: s, nicID: id, name: name, + isTap: prefix == "tap", } endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { @@ -271,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. 
if info.Pkt.LinkHeader == nil { - hdr := &header.EthernetFields{ - SrcAddr: info.Route.LocalLinkAddress, - DstAddr: info.Route.RemoteLinkAddress, - Type: info.Proto, - } - if hdr.SrcAddr == "" { - hdr.SrcAddr = d.endpoint.LinkAddress() - } - - eth := make(header.Ethernet, header.EthernetMinimumSize) - eth.Encode(hdr) - vv.AppendView(buffer.View(eth)) - } else { - vv.AppendView(info.Pkt.LinkHeader) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } + vv.AppendView(info.Pkt.LinkHeader) } // Append upper headers. @@ -348,6 +337,7 @@ type tunEndpoint struct { stack *stack.Stack nicID tcpip.NICID name string + isTap bool } // DecRef decrements refcount of e, removes NIC if refcount goes to 0. @@ -356,3 +346,38 @@ func (e *tunEndpoint) DecRef() { e.stack.RemoveNIC(e.nicID) }) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0956d2c65..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,6 +26,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 949b3f2b2..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. 
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 63bf40562..c448a888f 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -81,9 +86,19 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7f27a840d..b0f57040c 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -162,7 +162,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: header.EthernetBroadcastAddress, } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) @@ -181,7 +181,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { - return broadcastMAC, true + return header.EthernetBroadcastAddress, true } if header.IsV4MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv4Address(addr), true @@ -216,8 +216,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return 0, false, true } -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - // NewProtocol returns an ARP network protocol. 
func NewProtocol() stack.NetworkProtocol { return &protocol{} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7c8fb3e0a..615bae648 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,14 +172,24 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 78420d6e6..d142b4ffa 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -34,6 +34,6 @@ go_test( "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1b67aa066..83e71cb8c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7e9f16c90..b1776e5ee 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -225,12 +225,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. 
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -376,13 +374,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 3f71fc520..feada63dc 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -39,6 +39,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2ff7eedf4..ff1cb53dd 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) + case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } @@ -494,8 +496,6 @@ const ( icmpV6LengthOffset = 25 ) -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index edc29ad27..f6d592eb5 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -52,6 +52,9 @@ type Flags struct { // // LoadBalanced takes precedence over MostRecent. LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool } // Bits converts the Flags to their bitset form. @@ -63,6 +66,9 @@ func (f Flags) Bits() BitFlags { if f.LoadBalanced { rf |= LoadBalancedFlag } + if f.TupleOnly { + rf |= TupleOnlyFlag + } return rf } @@ -98,6 +104,9 @@ const ( // LoadBalancedFlag represents Flags.LoadBalanced. LoadBalancedFlag + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + // nextFlag is the value that the next added flag will have. // // It is used to calculate FlagMask below. It is also the number of @@ -106,6 +115,10 @@ const ( // FlagMask is a bit mask for BitFlags. FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag ) // ToFlags converts the bitset into a Flags struct. 
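Because the port-reuse bookkeeping above is all bitset arithmetic, a short self-contained sketch may make it concrete. The concrete bit values are an assumption based on the iota pattern implied by the const block (MostRecentFlag's definition sits outside this diff), and the types here only mirror, not import, the ports package.

package main

import "fmt"

// BitFlags mirrors ports.BitFlags for illustration only.
type BitFlags uint32

const (
	MostRecentFlag BitFlags = 1 << iota // Flags.MostRecent
	LoadBalancedFlag                    // Flags.LoadBalanced (SO_REUSEPORT)
	TupleOnlyFlag                       // Flags.TupleOnly (TCP SO_REUSEADDR)
	nextFlag                            // first unused bit

	// FlagMask covers every defined flag; MultiBindFlagMask is the subset
	// that allows binding the same (addr, port) more than once.
	FlagMask          = nextFlag - 1
	MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
)

func main() {
	// The availability check boils down to: the intersection of the
	// existing refs' flags and the new bind's flags must be non-zero.
	existing := LoadBalancedFlag | TupleOnlyFlag
	newBind := TupleOnlyFlag
	fmt.Printf("existing=%03b new=%03b compatible=%t\n",
		existing, newBind, existing&newBind != 0)
	fmt.Printf("mask=%04b multibind=%04b\n", FlagMask, MultiBindFlagMask)
}

In the real package the same intersection test is applied per destination once TupleOnly entries are involved, which is what portNode.intersectionRefs further down implements.
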
@@ -113,6 +126,7 @@ func (f BitFlags) ToFlags() Flags { return Flags{ MostRecent: f&MostRecentFlag != 0, LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, } } @@ -175,9 +189,54 @@ func (c FlagCounter) IntersectionRefs() BitFlags { return intersection } +type destination struct { + addr tcpip.Address + port uint16 +} + +func makeDestination(a tcpip.FullAddress) destination { + return destination{ + a.Addr, + a.Port, + } +} + +// portNode is never empty. When it has no elements, it is removed from the +// map that references it. +type portNode map[destination]FlagCounter + +// intersectionRefs calculates the intersection of flag bit values which affect +// the specified destination. +// +// If no destinations are present, all flag values are returned as there are no +// entries to limit possible flag values of a new entry. +// +// In addition to the intersection, the number of intersecting refs is +// returned. +func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { + intersection := FlagMask + var count int + + for d, f := range p { + if d == dst { + intersection &= f.IntersectionRefs() + count++ + continue + } + // Wildcard destinations affect all destinations for TupleOnly. + if d.addr == anyIPAddress || dst.addr == anyIPAddress { + // Only bitwise and the TupleOnlyFlag. + intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + count++ + } + } + + return intersection, count +} + // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]FlagCounter +type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all FlagCounters. If binding to a specific device, check @@ -186,17 +245,15 @@ type deviceNode map[tcpip.NICID]FlagCounter // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { flagBits := flags.Bits() if bindToDevice == 0 { - // Trying to binding all devices. - if flagBits == 0 { - // Can't bind because the (addr,port) is already bound. - return false - } intersection := FlagMask for _, p := range d { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) + if c == 0 { + continue + } intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -210,16 +267,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.IntersectionRefs() - if intersection&flagBits == 0 { + var c int + intersection, c = p.intersectionRefs(dst) + if c > 0 && intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) intersection &= i - if intersection&flagBits == 0 { + if c > 0 && intersection&flagBits == 0 { return false } } @@ -233,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. 
-func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -247,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -320,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice, dst) { return false } } @@ -342,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() + dst := makeDestination(dest) + // If a port is specified, just try to reserve it for all network // protocols. 
if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } return port, nil @@ -357,15 +417,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { return false } + flagBits := flags.Bits() // Reserve port on all network protocols. @@ -381,9 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - n := d[bindToDevice] + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + n := p[dst] n.AddRef(flagBits) - d[bindToDevice] = n + p[dst] = n + d[bindToDevice] = p + } + + return true +} + +// ReserveTuple adds a port reservation for the tuple on all given protocols. +func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { + flagBits := flags.Bits() + dst := makeDestination(dest) + + s.mu.Lock() + defer s.mu.Unlock() + + // It is easier to undo the entire reservation, so if we find that the + // tuple can't be fully added, finish and undo the whole thing. + undo := false + + // Reserve port on all network protocols. + for _, network := range networks { + desc := portDescriptor{network, transport, port} + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) + s.allocatedPorts[desc] = m + } + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + + n := p[dst] + if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + // Tuple already exists. + undo = true + } + n.AddRef(flagBits) + p[dst] = n + d[bindToDevice] = p + } + + if undo { + // releasePortLocked decrements the counts (rather than setting + // them to zero), so it will undo the incorrect incrementing + // above. + s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + return false } return true @@ -391,12 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. 
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.Bits() + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) +} +func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -404,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n, ok := d[bindToDevice] + p, ok := d[bindToDevice] if !ok { continue } - n.refs[flagBits]-- - d[bindToDevice] = n - if n.TotalRefs() == 0 { - delete(d, bindToDevice) + n, ok := p[dst] + if !ok { + continue + } + n.DropRef(flags) + if n.TotalRefs() > 0 { + p[dst] = n + continue } - if len(d) == 0 { - delete(m, addr) + delete(p, dst) + if len(p) > 0 { + continue + } + delete(d, bindToDevice) + if len(d) > 0 { + continue } - if len(m) == 0 { - delete(s.allocatedPorts, desc) + delete(m, addr) + if len(m) > 0 { + continue } + delete(s.allocatedPorts, desc) } } } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index d6969d050..58db5868c 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -36,6 +36,7 @@ type portReserveTestAction struct { flags Flags release bool device tcpip.NICID + dest tcpip.FullAddress } func TestPortReservation(t *testing.T) { @@ -272,6 +273,54 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind two tuples with reuseaddr", + actions: 
[]portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind two tuples", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind wildcard, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard twice with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -280,19 +329,18 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + t.Fatalf("ReservePort(.., .., .., 0, ..) 
= %d, want port number >= %d to be picked", gotPort, FirstEphemeral) } } }) - } } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 24f52b735..6b9a6b316 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -27,6 +27,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,6 +47,7 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", @@ -48,7 +61,9 @@ go_library( "route.go", "stack.go", "stack_global_state.go", + "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ @@ -78,6 +93,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", "//pkg/rand", @@ -93,7 +109,7 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 05bf62788..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -26,280 +26,321 @@ import ( ) // Connection tracking is used to track and manipulate packets for NAT rules. -// The connection is created for a packet if it does not exist. Every connection -// contains two tuples (original and reply). The tuples are manipulated if there -// is a matching NAT rule. The packet is modified by looking at the tuples in the -// Prerouting and Output hooks. +// The connection is created for a packet if it does not exist. Every +// connection contains two tuples (original and reply). The tuples are +// manipulated if there is a matching NAT rule. The packet is modified by +// looking at the tuples in the Prerouting and Output hooks. +// +// Currently, only TCP tracking is supported. + +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 // Direction of the tuple. -type ctDirection int +type direction int const ( - dirOriginal ctDirection = iota + dirOriginal direction = iota dirReply ) -// Status of connection. -// TODO(gvisor.dev/issue/170): Add other states of connection. -type connStatus int - -const ( - connNew connStatus = iota - connEstablished -) - // Manipulation type for the connection. type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) -// connTrackMutable is the manipulatable part of the tuple. -type connTrackMutable struct { - // addr is source address of the tuple. - addr tcpip.Address - - // port is source port of the tuple. - port uint16 - - // protocol is network layer protocol. - protocol tcpip.NetworkProtocolNumber -} - -// connTrackImmutable is the non-manipulatable part of the tuple. -type connTrackImmutable struct { - // addr is destination address of the tuple. - addr tcpip.Address +// tuple holds a connection's identifying and manipulating data in one +// direction. It is immutable. +// +// +stateify savable +type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry - // direction is direction (original or reply) of the tuple. 
- direction ctDirection + tupleID - // port is destination port of the tuple. - port uint16 + // conn is the connection tracking entry this tuple belongs to. + conn *conn - // protocol is transport layer protocol. - protocol tcpip.TransportProtocolNumber + // direction is the direction of the tuple. + direction direction } -// connTrackTuple represents the tuple which is created from the -// packet. -type connTrackTuple struct { - // dst is non-manipulatable part of the tuple. - dst connTrackImmutable - - // src is manipulatable part of the tuple. - src connTrackMutable +// tupleID uniquely identifies a connection in one direction. It currently +// contains enough information to distinguish between any TCP or UDP +// connection, and will need to be extended to support other protocols. +// +// +stateify savable +type tupleID struct { + srcAddr tcpip.Address + srcPort uint16 + dstAddr tcpip.Address + dstPort uint16 + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber } -// connTrackTupleHolder is the container of tuple and connection. -type ConnTrackTupleHolder struct { - // conn is pointer to the connection tracking entry. - conn *connTrack - - // tuple is original or reply tuple. - tuple connTrackTuple +// reply creates the reply tupleID. +func (ti tupleID) reply() tupleID { + return tupleID{ + srcAddr: ti.dstAddr, + srcPort: ti.dstPort, + dstAddr: ti.srcAddr, + dstPort: ti.srcPort, + transProto: ti.transProto, + netProto: ti.netProto, + } } -// connTrack is the connection. -type connTrack struct { - // originalTupleHolder contains tuple in original direction. - originalTupleHolder ConnTrackTupleHolder - - // replyTupleHolder contains tuple in reply direction. - replyTupleHolder ConnTrackTupleHolder - - // status indicates connection is new or established. - status connStatus +// conn is a tracked connection. +// +// +stateify savable +type conn struct { + // original is the tuple in original direction. It is immutable. + original tuple - // timeout indicates the time connection should be active. - timeout time.Duration + // reply is the tuple in reply direction. It is immutable. + reply tuple - // manip indicates if the packet should be manipulated. + // manip indicates if the packet should be manipulated. It is immutable. manip manipType - // tcb is TCB control block. It is used to keep track of states - // of tcp connection. - tcb tcpconntrack.TCB - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. + // update the state of tcb. It is immutable. tcbHook Hook + + // mu protects all mutable state. + mu sync.Mutex `state:"nosave"` + // tcb is TCB control block. It is used to keep track of states + // of tcp connection and is protected by mu. + tcb tcpconntrack.TCB + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` } -// ConnTrackTable contains a map of all existing connections created for -// NAT rules. -type ConnTrackTable struct { - // connMu protects connTrackTable. - connMu sync.RWMutex +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. 
+ return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout +} - // connTrackTable maintains a map of tuples needed for connection tracking - // for iptables NAT rules. The key for the map is an integer calculated - // using seed, source address, destination address, source port and - // destination port. - CtMap map[uint32]ConnTrackTupleHolder +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} +// ConnTrack tracks all connections created for NAT rules. Most users are +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable +type ConnTrack struct { // seed is a one-time random value initialized at stack startup - // and is used in calculation of hash key for connection tracking - // table. - Seed uint32 + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 + + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket } -// packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { - var tuple connTrackTuple +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList +} - netHeader := header.IPv4(pkt.NetworkHeader) +// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid +// TCP header. +func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. 
+ netHeader := header.IPv4(pkt.NetworkHeader) if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } - tuple.src.addr = netHeader.SourceAddress() - tuple.src.port = tcpHeader.SourcePort() - tuple.src.protocol = header.IPv4ProtocolNumber - - tuple.dst.addr = netHeader.DestinationAddress() - tuple.dst.port = tcpHeader.DestinationPort() - tuple.dst.protocol = netHeader.TransportProtocol() - - return tuple, nil + return tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: tcpHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: tcpHeader.DestinationPort(), + transProto: netHeader.TransportProtocol(), + netProto: header.IPv4ProtocolNumber, + }, nil } -// getReplyTuple creates reply tuple for the given tuple. -func getReplyTuple(tuple connTrackTuple) connTrackTuple { - var replyTuple connTrackTuple - replyTuple.src.addr = tuple.dst.addr - replyTuple.src.port = tuple.dst.port - replyTuple.src.protocol = tuple.src.protocol - replyTuple.dst.addr = tuple.src.addr - replyTuple.dst.port = tuple.src.port - replyTuple.dst.protocol = tuple.dst.protocol - replyTuple.dst.direction = dirReply - - return replyTuple +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn } -// makeNewConn creates new connection. -func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { - var conn connTrack - conn.status = connNew - conn.originalTupleHolder.tuple = tuple - conn.originalTupleHolder.conn = &conn - conn.replyTupleHolder.tuple = replyTuple - conn.replyTupleHolder.conn = &conn +// connFor gets the conn for pkt if it exists, or returns nil +// if it does not. It returns an error when pkt does not contain a valid TCP +// header. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support +// other transport protocols. +func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil, dirOriginal + } - return conn -} + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } + } -// getTupleHash returns hash of the tuple. The fields used for -// generating hash are seed (generated once for stack), source address, -// destination address, source port and destination ports. 
-func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { - h := jenkins.Sum32(ct.Seed) - h.Write([]byte(tuple.src.addr)) - h.Write([]byte(tuple.dst.addr)) - portBuf := make([]byte, 2) - binary.LittleEndian.PutUint16(portBuf, tuple.src.port) - h.Write([]byte(portBuf)) - binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) - h.Write([]byte(portBuf)) - - return h.Sum32() + return nil, dirOriginal } -// connTrackForPacket returns connTrack for packet. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other -// transport protocols. -func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - var dir ctDirection - tuple, err := packetToTuple(pkt, hook) +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { + tid, err := packetToTupleID(pkt) if err != nil { - return nil, dir - } - - ct.connMu.Lock() - defer ct.connMu.Unlock() - - connTrackTable := ct.CtMap - hash := ct.getTupleHash(tuple) - - var conn *connTrack - switch createConn { - case true: - // If connection does not exist for the hash, create a new - // connection. - replyTuple := getReplyTuple(tuple) - replyHash := ct.getTupleHash(replyTuple) - newConn := makeNewConn(tuple, replyTuple) - conn = &newConn - - // Add tupleHolders to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. - ct.CtMap[hash] = conn.originalTupleHolder - ct.CtMap[replyHash] = conn.replyTupleHolder - default: - tupleHolder, ok := connTrackTable[hash] - if !ok { - return nil, dir - } - - // If this is the reply of new connection, set the connection - // status as ESTABLISHED. - conn = tupleHolder.conn - if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { - conn.status = connEstablished - } - if tupleHolder.conn == nil { - panic("tupleHolder has null connection tracking entry") - } + return nil + } + if hook != Prerouting && hook != Output { + return nil + } - dir = tupleHolder.tuple.dst.direction + // Create a new connection and change the port as per the iptables + // rule. This tuple will be used to manipulate the packet in + // handlePacket. + replyTID := tid.reply() + replyTID.srcAddr = rt.MinIP + replyTID.srcPort = rt.MinPort + var manip manipType + switch hook { + case Prerouting: + manip = manipDstPrerouting + case Output: + manip = manipDstOutput } - return conn, dir + conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn } -// SetNatInfo will manipulate the tuples according to iptables NAT rules. -func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { - // Get the connection. Connection is always created before this - // function is called. - conn, _ := ct.connTrackForPacket(pkt, hook, false) - if conn == nil { - panic("connection should be created to manipulate tuples.") +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. 
+ ct.buckets[tupleBucket].mu.Lock() } - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to - // support changing of address for Prerouting. - - // Change the port as per the iptables rule. This tuple will be used - // to manipulate the packet in HandlePacket. - conn.replyTupleHolder.tuple.src.addr = rt.MinIP - conn.replyTupleHolder.tuple.src.port = rt.MinPort - newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } - // Add the changed tuple to the map. - ct.connMu.Lock() - defer ct.connMu.Unlock() - ct.CtMap[newHash] = conn.replyTupleHolder - if hook == Output { - conn.replyTupleHolder.conn.manip = manipDstOutput + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) } - // Delete the old tuple. - delete(ct.CtMap, replyHash) + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. -func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. +func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -308,21 +349,31 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) // modified. switch dir { case dirOriginal: - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) case dirReply: - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { +func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -331,13 +382,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, // modified. 
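The clear-then-set pattern above follows from the definition of the IPv4 header checksum (RFC 1071): the ones' complement of the ones'-complement sum of the header, computed with the checksum field itself zeroed. A standalone illustration of the arithmetic, not netstack code:

func ipChecksum(hdr []byte) uint16 {
	var sum uint32
	// Sum the header as big-endian 16-bit words.
	for i := 0; i+1 < len(hdr); i += 2 {
		sum += uint32(hdr[i])<<8 | uint32(hdr[i+1])
	}
	// An odd trailing byte is padded with a zero byte.
	if len(hdr)%2 == 1 {
		sum += uint32(hdr[len(hdr)-1]) << 8
	}
	// Fold carries back into the low 16 bits, then complement.
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}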
For prerouting redirection, we only reach this point // when replying, so packet sources are modified. if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) } else { - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } // Calculate the TCP checksum and set it. @@ -356,33 +407,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, netHeader.SetChecksum(^netHeader.CalculateChecksum()) } -// HandlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// handlePacket will manipulate the port and address of the packet if the +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false } - conn, dir := ct.connTrackForPacket(pkt, hook, false) - // Connection or Rule not found for the packet. - if conn == nil { - return + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } - netHeader := header.IPv4(pkt.NetworkHeader) - // TODO(gvisor.dev/issue/170): Need to support for other transport - // protocols as well. - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return + conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. + if conn == nil { + return true } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return + return false } switch hook { @@ -396,39 +446,161 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle // other tcp states. - var st tcpconntrack.Result - if conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + conn.mu.Lock() + defer conn.mu.Unlock() + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return } - // Delete conntrack if tcp connection is closed. 
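conn.updateLocked is defined in conntrack.go, outside the hunks shown here. Judging by the tcb logic it replaces (removed above), it plausibly folds the old switch into a method on conn; a sketch under that assumption, ignoring the tcpconntrack result the old code inspected:

func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
	if cn.tcb.IsEmpty() {
		// First packet for this connection: initialize the TCB and
		// remember which hook saw it.
		cn.tcb.Init(tcpHeader)
		cn.tcbHook = hook
		return
	}
	// Packets seen on the hook that initialized the TCB are outbound;
	// packets seen anywhere else are inbound.
	if hook == cn.tcbHook {
		cn.tcb.UpdateStateOutbound(tcpHeader)
	} else {
		cn.tcb.UpdateStateInbound(tcpHeader)
	}
}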
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConnTrack(conn) + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return } -} -// deleteConnTrack deletes the connection. -func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { - if conn == nil { + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) +} - tuple := conn.originalTupleHolder.tuple - hash := ct.getTupleHash(tuple) - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} - ct.connMu.Lock() - defer ct.connMu.Unlock() +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. 
Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval + } + // We've hit the maximum interval. + return idx, maxInterval +} + +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false + } + + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } - delete(ct.CtMap, hash) - delete(ct.CtMap, replyHash) + return true } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index a6546cef0..bca1d940b 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -301,6 +302,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 4e9b404c8..cbbae4224 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,39 +16,49 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) -// Table names. +// tableID is an index into IPTables.tables. +type tableID int + const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" + natID tableID = iota + mangleID + filterID + numTables ) -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +// Table names. const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" + NATTable = "nat" + MangleTable = "mangle" + FilterTable = "filter" ) +// nameToID is immutable. 
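Restating the rescheduling policy of reapUnused above as a standalone helper, with the same constants; this is a paraphrase of the code already shown, not new behavior:

func nextReapInterval(prev time.Duration, checked, expired int) time.Duration {
	const (
		fractionPerReaping = 128
		maxExpiredPct      = 50
		maxFullTraversal   = 60 * time.Second
		minInterval        = 10 * time.Millisecond
		maxInterval        = maxFullTraversal / fractionPerReaping // ~469ms
	)
	expiredPct := 0
	if checked != 0 {
		expiredPct = expired * 100 / checked
	}
	if expiredPct > maxExpiredPct {
		// The table is stale; come back quickly.
		return minInterval
	}
	// Healthy table: back off a little each pass, capped so that a full
	// traversal of all fractions still completes within maxFullTraversal.
	if next := prev + minInterval; next <= maxInterval {
		return next
	}
	return maxInterval
}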
+var nameToID = map[string]tableID{ + NATTable: natID, + MangleTable: mangleID, + FilterTable: filterID, +} + // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. return &IPTables{ - tables: map[string]Table{ - TablenameNat: Table{ + tables: [numTables]Table{ + natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -56,65 +66,71 @@ func DefaultTables() *IPTables { Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - Underflows: map[Hook]int{ + Underflows: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - UserChains: map[string]int{}, }, - TablenameMangle: Table{ + mangleID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Output: 1, }, - Underflows: map[Hook]int{ - Prerouting: 0, - Output: 1, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, - TablenameFilter: Table{ + filterID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, }, - priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + priorities: [NumHooks][]tableID{ + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, - connections: ConnTrackTable{ - CtMap: make(map[uint32]ConnTrackTupleHolder), - Seed: generateRandUint32(), + connections: ConnTrack{ + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -123,69 +139,59 @@ func DefaultTables() *IPTables { func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, } } -// EmptyNatTable returns a Table with no rules and the filter table chains +// EmptyNATTable returns a Table with no rules and the filter table chains // mapped to HookUnset. 
-func EmptyNatTable() Table { +func EmptyNATTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + BuiltinChains: [NumHooks]int{ + Forward: HookUnset, }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + Underflows: [NumHooks]int{ + Forward: HookUnset, }, - UserChains: map[string]int{}, } } -// GetTable returns table by name. +// GetTable returns a table by name. func (it *IPTables) GetTable(name string) (Table, bool) { + id, ok := nameToID[name] + if !ok { + return Table{}, false + } it.mu.RLock() defer it.mu.RUnlock() - t, ok := it.tables[name] - return t, ok + return it.tables[id], true } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) { - it.mu.Lock() - defer it.mu.Unlock() - it.tables[name] = table -} - -// ModifyTables acquires write-lock and calls fn with internal name-to-table -// map. This function can be used to update multiple tables atomically. -func (it *IPTables) ModifyTables(fn func(map[string]Table)) { +func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { + id, ok := nameToID[name] + if !ok { + return tcpip.ErrInvalidOptionValue + } it.mu.Lock() defer it.mu.Unlock() - fn(it.tables) -} - -// GetPriorities returns slice of priorities associated with hook. -func (it *IPTables) GetPriorities(hook Hook) []string { - it.mu.RLock() - defer it.mu.RUnlock() - return it.priorities[hook] + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } + it.modified = true + it.tables[id] = table + return nil } // A chainVerdict is what a table decides should be done with a packet. @@ -209,13 +215,30 @@ const ( // // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { + // Many users never configure iptables. Spare them the cost of rule + // traversal if rules have never been set. + it.mu.RLock() + if !it.modified { + it.mu.RUnlock() + return true + } + it.mu.RUnlock() + // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.HandlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.GetPriorities(hook) { - table, _ := it.GetTable(tablename) + it.mu.RLock() + defer it.mu.RUnlock() + priorities := it.priorities[hook] + for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } + table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -244,17 +267,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. 
This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -278,9 +343,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -325,23 +390,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. 
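Since ReplaceTable now validates the table name and returns *tcpip.Error, callers must check the result; the first successful call is also what lazily allocates the conntrack buckets and starts the reaper. A hypothetical caller, for illustration only:

func installFilterTable(ipt *stack.IPTables, table stack.Table) error {
	// FilterTable is one of the three builtin table names; any other
	// name now fails with tcpip.ErrInvalidOptionValue.
	if err := ipt.ReplaceTable(stack.FilterTable, table); err != nil {
		return fmt.Errorf("ReplaceTable(%q): %v", stack.FilterTable, err)
	}
	return nil
}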
if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 92e31643e..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route type ReturnTarget struct{} // Action implements Target.Action. 
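Every Target implementation changes the same way: Action now takes a *ConnTrack where it previously took a *ConnTrackTable. A sketch of what an out-of-tree target would look like against the new signature (auditTarget itself is hypothetical):

type auditTarget struct{}

// Action implements Target.Action; it logs the hook and lets the packet
// continue, never touching the connection-tracking state.
func (auditTarget) Action(pkt *stack.PacketBuffer, ct *stack.ConnTrack, hook stack.Hook, gso *stack.GSO, r *stack.Route, address tcpip.Address) (stack.RuleVerdict, int) {
	log.Debugf("auditTarget: packet on hook %d", hook)
	return stack.RuleAccept, 0
}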
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0 }
@@ -92,7 +92,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0
@@ -150,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 }
- // Set up conection for matching NAT rule.
- // Only the first packet of the connection comes here.
- // Other packets will be manipulated in connection tracking.
- if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil {
- ct.SetNatInfo(pkt, rt, hook)
- ct.HandlePacket(pkt, hook, gso, r)
+ // Set up a connection for the matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
}
default: return RuleDrop, 0
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4a6a5c6f1..73274ada9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -78,67 +78,65 @@ const ( ) // IPTables holds all the tables for a netstack.
+//
+// +stateify savable
type IPTables struct {
- // mu protects tables and priorities.
+ // mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names. mu
- // needs to be locked for accessing.
- tables map[string]Table
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
// priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
+ priorities [NumHooks][]tableID
+
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
- connections ConnTrackTable
+ connections ConnTrack
+
+ // reaperDone can be signalled to stop the reaper goroutine.
+ reaperDone chan struct{}
}
// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules with some metadata for entrypoints and such.
+// really just a list of rules.
+//
+// +stateify savable
type Table struct { // Rules holds the rules that make up the table. Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules // (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name.
Users - // can give their chains arbitrary names. - UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} + Underflows [NumHooks]int } // ValidHooks returns a bitmap of the builtin hooks for the given table. func (table *Table) ValidHooks() uint32 { hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook + for hook, ruleIdx := range table.BuiltinChains { + if ruleIdx != HookUnset { + hooks |= 1 << hook + } } return hooks } -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -151,6 +149,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber @@ -258,5 +258,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index e28c23d66..9dce11a97 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -469,7 +469,7 @@ type ndpState struct { rtrSolicit struct { // The timer used to send the next router solicitation message. - timer *time.Timer + timer tcpip.Timer // Used to let the Router Solicitation timer know that it has been stopped. // @@ -503,7 +503,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // @@ -515,38 +515,38 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - // Timer to invalidate the default router. + // Job to invalidate the default router. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - // Timer to invalidate the on-link prefix. + // Job to invalidate the on-link prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // tempSLAACAddrState holds state associated with a temporary SLAAC address. 
type tempSLAACAddrState struct { - // Timer to deprecate the temporary SLAAC address. + // Job to deprecate the temporary SLAAC address. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the temporary SLAAC address. + // Job to invalidate the temporary SLAAC address. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job - // Timer to regenerate the temporary SLAAC address. + // Job to regenerate the temporary SLAAC address. // // Must not be nil. - regenTimer *tcpip.CancellableTimer + regenJob *tcpip.Job createdAt time.Time @@ -561,15 +561,15 @@ type tempSLAACAddrState struct { // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - // Timer to deprecate the prefix. + // Job to deprecate the prefix. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the prefix. + // Job to invalidate the prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. validUntil time.Time @@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } var done bool - var timer *time.Timer + var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work // cannot be done while holding the NIC's lock. This is effectively the same // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. - timer = time.AfterFunc(0, func() { + timer = ndp.nic.stack.Clock().AfterFunc(0, func() { ndp.nic.mu.Lock() defer ndp.nic.mu.Unlock() @@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update - // the invalidation timer. - rtr.invalidationTimer.StopLocked() - rtr.invalidationTimer.Reset(rl) + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.StopLocked() + rtr.invalidationJob.Cancel() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. 
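The rest of this file follows one mechanical pattern: every *tcpip.CancellableTimer becomes a *tcpip.Job, Reset becomes Schedule, StopLocked becomes Cancel, and construction moves from tcpip.NewCancellableTimer to the stack's clock-aware newJob. Schematically, with fn and lifetime as placeholders:

// Before: a timer bound to the NIC lock, driven by the wall clock.
timer := tcpip.NewCancellableTimer(&ndp.nic.mu, fn)
timer.Reset(lifetime) // arm (or re-arm)
timer.StopLocked()    // cancel; ndp.nic.mu must be held

// After: a job created through the stack, so it uses the stack's clock.
job := ndp.nic.stack.newJob(&ndp.nic.mu, fn)
job.Schedule(lifetime) // arm (or re-arm)
job.Cancel()           // cancel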
@@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } - state.invalidationTimer.Reset(rl) + state.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = state } @@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } if l < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(l) + state.invalidationJob.Schedule(l) } ndp.onLinkPrefixes[prefix] = state @@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - s.invalidationTimer.StopLocked() + s.invalidationJob.Cancel() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. // - // Update the invalidation timer. + // Update the invalidation job. - prefixState.invalidationTimer.StopLocked() + prefixState.invalidationJob.Cancel() if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, reset the timer to expire after + // Prefix is valid for a finite lifetime, schedule the job to execute after // the new valid lifetime. - prefixState.invalidationTimer.Reset(vl) + prefixState.invalidationJob.Schedule(vl) } ndp.onLinkPrefixes[prefix] = prefixState @@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.stableAddr.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if !ndp.generateSLAACAddr(prefix, &state) { // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or timers for a prefix we + // further as there is no reason to maintain state or jobs for a prefix we // do not have an address for. return } - // Setup the initial timers to deprecate and invalidate prefix. + // Setup the initial jobs to deprecate and invalidate prefix. 
if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationTimer.Reset(pl) + state.deprecationJob.Schedule(pl) } if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) + state.invalidationJob.Schedule(vl) state.validUntil = now.Add(vl) } @@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } state := tempSLAACAddrState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.deprecateSLAACAddress(tempAddrState.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ref: ref, } - state.deprecationTimer.Reset(pl) - state.invalidationTimer.Reset(vl) - state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) prefixState.generationAttempts++ prefixState.tempAddrs[generatedAddr.Address] = state @@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat prefixState.stableAddr.ref.deprecated = false } - // If prefix was preferred for some finite lifetime before, stop the - // deprecation timer so it can be reset. - prefixState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() now := time.Now() - // Reset the deprecation timer if prefix has a finite preferred lifetime. + // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { - prefixState.deprecationTimer.Reset(pl) + prefixState.deprecationJob.Schedule(pl) } prefixState.preferredUntil = now.Add(pl) } else { @@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not keep a timer - // in this case. - prefixState.invalidationTimer.StopLocked() + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. 
+ prefixState.invalidationJob.Cancel() prefixState.validUntil = time.Time{} } else { var effectiveVl time.Duration @@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } if effectiveVl != 0 { - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) prefixState.validUntil = now.Add(effectiveVl) } } @@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // Note, we do not need to update the entries in the temporary address map - // after updating the timers because the timers are held as pointers. + // after updating the jobs because the jobs are held as pointers. var regenForAddr tcpip.Address allAddressesRegenerated := true for tempAddr, tempAddrState := range prefixState.tempAddrs { @@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation timer. + // reset the invalidation job. newValidLifetime := validUntil.Sub(now) if newValidLifetime <= 0 { ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) continue } - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.invalidationTimer.Reset(newValidLifetime) + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary // address is the lower of the preferred lifetime of the stable address or @@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer preferred, deprecate it immediately. - // Otherwise, reset the deprecation timer. + // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.ref) } else { tempAddrState.ref.deprecated = false - tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } - tempAddrState.regenTimer.StopLocked() + tempAddrState.regenJob.Cancel() if tempAddrState.regenerated { } else { allAddressesRegenerated = false @@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // immediately after we finish iterating over the temporary addresses. regenForAddr = tempAddr } else { - tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) } } } @@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. // // Panics if the SLAAC prefix is not known. 
// @@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa } state.stableAddr.ref = nil - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) } @@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi } // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// timers and entry. +// jobs and entry. // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationTimer.StopLocked() - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.regenTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } @@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { ndp.nic.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ae326b3ab..644ba7c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -36,15 +36,24 @@ import ( ) const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - defaultTimeout = 100 * time.Millisecond - defaultAsyncEventTimeout = time.Second + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 10 * time.Second + + // Extra time to use when waiting for an async event to not occur. + // + // Since a negative check is used to make sure an event did not happen, it is + // okay to use a smaller timeout compared to the positive case since execution + // stall in regards to the monotonic clock will not affect the expected + // outcome. + defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -442,7 +451,7 @@ func TestDADResolve(t *testing.T) { // Make sure the address does not resolve before the resolution time has // passed. 
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } else if want := (tcpip.AddressWithPrefix{}); addr != want { @@ -471,7 +480,7 @@ func TestDADResolve(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(2 * defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { @@ -1169,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -1245,14 +1254,14 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Wait for lladdr2's router invalidation timer to fire. The lifetime + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1262,14 +1271,14 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Wait for lladdr3's router invalidation timer to fire. The lifetime + // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1418,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -1493,14 +1502,14 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Wait for prefix2's most recent invalidation timer plus some buffer to + // Wait for prefix2's most recent invalidation job plus some buffer to // expire. 
select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout): + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for prefix discovery event") } @@ -1565,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): } // Receive an RA with finite lifetime. @@ -1590,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): } // Receive an RA with a prefix with a lifetime value greater than the @@ -1599,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): } // Receive an RA with 0 lifetime. @@ -1835,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1962,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -1975,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -2081,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { @@ -2180,7 
+2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } @@ -2188,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -2265,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } select { @@ -2273,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -2363,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2386,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { for _, addr := range tempAddrs { // Wait for a deprecation then invalidation event, or just an invalidation // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation timers could fire in any + // cases because the deprecation and invalidation jobs could execute in any // order. 
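The select in the hunk that follows handles exactly this race. Distilled into a standalone helper, the pattern looks like the sketch below; the string-based event type and channel are assumptions standing in for the test's auto-gen address event plumbing:

// expectEitherOrder accepts the two legal orders described above:
// deprecation followed by invalidation, or invalidation alone.
func expectEitherOrder(t *testing.T, events <-chan string, deadline time.Duration) {
	t.Helper()
	select {
	case e := <-events:
		switch e {
		case "deprecated":
			// Deprecation won the race; invalidation must still follow.
			select {
			case e := <-events:
				if e != "invalidated" {
					t.Fatalf("got %q, want %q", e, "invalidated")
				}
			case <-time.After(defaultAsyncPositiveEventTimeout):
				t.Fatal("timed out waiting for invalidation")
			}
		case "invalidated":
			// Invalidation won; nothing further may arrive.
			select {
			case e := <-events:
				t.Fatalf("unexpected trailing event %q", e)
			case <-time.After(defaultAsyncNegativeEventTimeout):
			}
		default:
			t.Fatalf("unexpected event %q", e)
		}
	case <-time.After(deadline):
		t.Fatal("timed out waiting for deprecation or invalidation")
	}
}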
select { case e := <-ndpDisp.autoGenAddrC: @@ -2398,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { @@ -2407,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } } else { t.Fatalf("got unexpected auto-generated event = %+v", e) } - case <-time.After(invalidateAfter + defaultAsyncEventTimeout): + case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } @@ -2423,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } } -// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's -// regeneration timer gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( nicID = 1 regenAfter = 2 * time.Second @@ -2517,14 +2526,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncEventTimeout): + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): } // Prefer the prefix again. // // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a timer. + // - this regeneration does not depend on a job. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(tempAddr2, newAddr) @@ -2546,24 +2555,24 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncEventTimeout): + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): } // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration timer gets reset. + // RA, the regeneration job gets scheduled again. // // The maximum lifetime is the sum of the minimum lifetimes for temporary // addresses + the time that has already passed since the last address was - // generated so that the regeneration timer is needed to generate the next + // generated so that the regeneration job is needed to generate the next // address. 
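The comments above encode a two-way rule: if the regeneration point has already passed by the time the prefix is preferred again, a fresh temporary address is generated synchronously; otherwise the regeneration job covers the remainder. The assignment in the next hunk picks lifetimes that force the scheduled path. A hedged distillation, with hypothetical names (only Job.Schedule is from this patch):

// regenerateOrSchedule sketches the rule the comments describe.
func regenerateOrSchedule(job *tcpip.Job, sinceLastGen, regenAfter time.Duration, generate func()) {
	if sinceLastGen >= regenAfter {
		// The regeneration time already elapsed: generate a new
		// temporary address immediately, no job involved.
		generate()
		return
	}
	// Otherwise let the job fire once the remainder elapses.
	job.Schedule(regenAfter - sinceLastGen)
}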
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2711,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -2724,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -2984,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr(addr2) } -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated // when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { +func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second @@ -3070,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3110,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3124,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. 
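expectAutoGenAddrEventAfter, called throughout these hunks, is defined elsewhere in ndp_test.go. A plausible shape, reconstructed from its call sites here (the dispatcher, event, and checker identifiers are assumptions drawn from the surrounding test code):

expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
	t.Helper()
	select {
	case e := <-ndpDisp.autoGenAddrC:
		if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
			t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
		}
	case <-time.After(timeout):
		t.Fatal("timed out waiting for addr auto gen event")
	}
}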
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3156,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { @@ -3165,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3295,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout): + case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -3439,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): } // Wait for the invalidation event. @@ -3448,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(2 * defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -3504,12 +3513,12 @@ func TestAutoGenAddrRemoval(t *testing.T) { } expectAutoGenAddrEvent(addr, invalidatedAddr) - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. 
select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -3672,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -3770,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3837,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -3863,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -4030,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -4149,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -4251,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } @@ -4277,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4285,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { 
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for auto gen addr event") } } @@ -4869,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) select { case <-ndpDisp.routerC: t.Error("unexpected router event") @@ -5172,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) remaining-- } for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout) - waitForPkt(2 * defaultAsyncEventTimeout) + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) + waitForPkt(defaultAsyncPositiveEventTimeout) } else { - waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) } } // Make sure no more RS. if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout) + waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) } // Make sure the counter got properly @@ -5305,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok = e.ReadContext(ctx); ok { t.Fatal("should not have sent more than one RS message") @@ -5319,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") @@ -5332,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. 
test.startFn(t, s) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) + waitForPkt(delay + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") @@ -5344,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index afb7dfeaf..fea0ce7e8 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // Are any packet sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet sockets that may be listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol-specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that may be listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. // TODO(b/151227689): Avoid copying the packet when forwarding.
We can do this @@ -1358,16 +1374,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // TransportHeader is nil only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader == nil { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } pkt.TransportHeader = transHeader + pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 31f865260..c477e31d8 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -84,6 +84,16 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1b5da6017..5d6865e35 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,6 +14,7 @@ package stack import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -24,7 +25,7 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { - _ noCopy + _ sync.NoCopy // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. @@ -78,6 +79,10 @@ type PacketBuffer struct { // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } // Clone makes a copy of pk. It clones the Data field, which creates a new @@ -102,14 +107,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NatDone: pk.NatDone, } } - -// noCopy may be embedded into structs which must not be copied -// after the first use. -// -// See https://golang.org/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} - -// Lock is a no-op used by -copylocks checker from `go vet`. 
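The new DeliverOutboundPacket hook above implies a caller on the transmit path. A hedged sketch of such a call site, a wrapping link endpoint that lets ETH_P_ALL packet sockets observe outgoing traffic before it hits the wire; the endpoint type, its fields, and the delegation to a lower endpoint are assumptions, not code from this patch:

// endpoint is a hypothetical link endpoint wrapper; only the fields used
// here are sketched.
type endpoint struct {
	dispatcher stack.NetworkDispatcher
	lower      stack.LinkEndpoint
}

func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
	// Give ETH_P_ALL packet sockets a copy of the outgoing packet. The
	// NIC clones pkt, tags the clone tcpip.PacketOutgoing, and prepends
	// the link header via LinkEndpoint.AddHeader (see the hunk above).
	e.dispatcher.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
	return e.lower.WritePacket(r, gso, protocol, pkt)
}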
-func (*noCopy) Lock() {} -func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5cbc946b6..9e1b2d25f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +52,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) @@ -329,8 +333,7 @@ type NetworkProtocol interface { } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -341,6 +344,16 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -436,6 +449,15 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a2190341c..a6faa22c2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -425,6 +425,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -471,6 +472,14 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. randomGenerator *mathrand.Rand + + // sendBufferSize holds the min/default/max send buffer sizes for + // endpoints other than TCP. + sendBufferSize SendBufferSizeOption + + // receiveBufferSize holds the min/default/max receive buffer sizes for + // endpoints other than TCP. 
+ receiveBufferSize ReceiveBufferSizeOption } // UniqueID is an abstract generator of unique identifiers. @@ -683,6 +692,16 @@ func New(opts Options) *Stack { tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, + receiveBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, } // Add specified network protocols. @@ -709,6 +728,11 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + // UniqueID returns a unique identifier. func (s *Stack) UniqueID() uint64 { return s.uniqueIDGenerator.UniqueID() @@ -782,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. @@ -1033,14 +1058,14 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { // Remove routes in-place. n tracks the number of routes written. n := 0 for i, r := range s.routeTable { + s.routeTable[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - if i > n { - s.routeTable[n] = r - } + s.routeTable[n] = r n++ } } + s.routeTable = s.routeTable[:n] return nic.remove() @@ -1076,6 +1101,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1107,6 +1137,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -1408,6 +1439,12 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } +// CheckRegisterTransportEndpoint checks if an endpoint can be registered with +// the stack transport dispatcher. +func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) +} + // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. 
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go new file mode 100644 index 000000000..0b093e6c5 --- /dev/null +++ b/pkg/tcpip/stack/stack_options.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "gvisor.dev/gvisor/pkg/tcpip" + +const ( + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4 KiB + + // DefaultBufferSize is the default size of the send/recv buffer for a + // transport endpoint. + DefaultBufferSize = 212 << 10 // 212 KiB + + // DefaultMaxBufferSize is the default maximum permitted size of a + // send/receive buffer. + DefaultMaxBufferSize = 4 << 20 // 4 MiB +) + +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// SetOption allows setting stack wide options. +func (s *Stack) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SendBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.sendBufferSize = v + s.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.receiveBufferSize = v + s.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option allows retrieving stack wide options. +func (s *Stack) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SendBufferSizeOption: + s.mu.RLock() + *v = s.sendBufferSize + s.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + s.mu.RLock() + *v = s.receiveBufferSize + s.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index ffef9bc2c..7657a4101 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3305,7 +3305,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { // Wait for DAD to resolve. 
select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { @@ -3338,3 +3338,83 @@ func TestDoDADWhenNICEnabled(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) } } + +func TestStackReceiveBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + rs stack.ReceiveBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.rs); err != tc.err { + t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) + } + var rs stack.ReceiveBufferSizeOption + if tc.err == nil { + if err := s.Option(&rs); err != nil { + t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) + } + if got, want := rs, tc.rs; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestStackSendBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + ss stack.SendBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. 
+ {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.ss); err != tc.err { + t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + } + var ss stack.SendBufferSizeOption + if tc.err == nil { + if err := s.Option(&ss); err != nil { + t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) + } + if got, want := ss, tc.ss; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 118b449d5..b902c6ca9 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -221,6 +221,18 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + return nil + } + + return multiPortEp.singleCheckEndpoint(flags) +} + // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() @@ -289,6 +301,17 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +// checkEndpoint checks if an endpoint can be registered with the dispatcher. +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + for _, n := range netProtos { + if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { + return err + } + } + + return nil +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. 
@@ -380,7 +403,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() + bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. @@ -395,6 +418,22 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { + ep.mu.RLock() + defer ep.mu.RUnlock() + + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { + // If it was previously bound, we need to check if we can bind again. + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + return tcpip.ErrPortInUse + } + } + + return nil +} + // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() @@ -406,7 +445,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports ep.endpoints[len(ep.endpoints)-1] = nil ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] - ep.flags.DropRef(flags.Bits()) + ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask) break } } @@ -439,6 +478,28 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return tcpip.ErrUnknownProtocol + } + + eps.mu.RLock() + defer eps.mu.RUnlock() + + epsByNIC, ok := eps.endpoints[id] + if !ok { + return nil + } + + return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b7b227328..21aafb0a2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } -// A Clock provides the current time. +// A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. @@ -203,6 +203,31 @@ type Clock interface { // NowMonotonic returns a monotonic time value. NowMonotonic() int64 + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. 
It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) } // Address is a byte slice cast as a string that represents the address of a @@ -316,6 +341,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -549,6 +596,28 @@ type Endpoint interface { SetOwner(owner PacketOwner) } +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) +} + // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -585,85 +654,108 @@ type WriteOptions struct { type SockOptBool int const ( - // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether - // datagram sockets are allowed to send packets to a broadcast address. + // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify + // whether datagram sockets are allowed to send packets to a broadcast + // address. BroadcastOption SockOptBool = iota - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be - // held until segments are full by the TCP transport protocol. + // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be held until segments are full by the TCP transport + // protocol. CorkOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. 
For TCP, - // it determines if the Nagle algorithm is on or off. + // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be sent out immediately by the transport protocol. For + // TCP, it determines if the Nagle algorithm is on or off. DelayOption - // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether - // TCP keepalive is enabled for this socket. + // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to + // specify whether TCP keepalive is enabled for this socket. KeepaliveEnabledOption - // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether - // multicast packets sent over a non-loopback interface will be looped back. + // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to + // specify whether multicast packets sent over a non-loopback interface + // will be looped back. MulticastLoopOption - // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether - // SCM_CREDENTIALS socket control messages are enabled. + // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify + // whether UDP checksum is disabled for this socket. + NoChecksumOption + + // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify + // whether SCM_CREDENTIALS socket control messages are enabled. // // Only supported on Unix sockets. PasscredOption - // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. QuickAckOption - // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the - // IPV6_TCLASS ancillary message is passed with incoming packets. + // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to + // specify if the IPV6_TCLASS ancillary message is passed with incoming + // packets. ReceiveTClassOption - // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS - // ancillary message is passed with incoming packets. + // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify + // if the TOS ancillary message is passed with incoming packets. ReceiveTOSOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify - // if more inforamtion is provided with incoming packets such - // as interface index and address. + // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to + // specify if more information is provided with incoming packets such as + // interface index and address. ReceiveIPPacketInfoOption - // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() - // should allow reuse of local address. + // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to + // specify whether Bind() should allow reuse of local address. ReuseAddressOption - // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets - // to be bound to an identical socket address. + // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit + // multiple sockets to be bound to an identical socket address. ReusePortOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. + // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify + // whether an IPv6 socket is to be restricted to sending and receiving + // IPv6 packets only.
V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. type SockOptInt int const ( - // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number - // of un-ACKed TCP keepalives that will be sent before the connection is - // closed. + // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to + // specify the number of un-ACKed TCP keepalives that will be sent + // before the connection is closed. KeepaliveCountOption SockOptInt = iota - // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS // for all subsequent outgoing IPv4 packets from the endpoint. IPv4TOSOption - // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS - // for all subsequent outgoing IPv6 packets from the endpoint. + // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to + // specify TOS for all subsequent outgoing IPv6 packets from the + // endpoint. IPv6TrafficClassOption - // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current - // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the + // current Maximum Segment Size(MSS) value as specified using the + // TCP_MAXSEG option. MaxSegOption - // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default - // TTL value for multicast messages. The default is 1. + // MTUDiscoverOption is used to set/get the path MTU discovery setting. + // + // NOTE: Setting this option to any other value than PMTUDiscoveryDont + // is not supported and will fail as such, and getting this option will + // always return PMTUDiscoveryDont. + MTUDiscoverOption + + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control + // the default TTL value for multicast messages. The default is 1. MulticastTTLOption // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the @@ -682,26 +774,45 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop - // limit value for unicast messages. The default is protocol specific. + // TTLOption is used by SetSockOptInt/GetSockOptInt to control the + // default TTL/hop limit value for unicast messages. The default is + // protocol specific. // // A zero value indicates the default. TTLOption - // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of - // SYN retransmits that TCP should send before aborting the attempt to - // connect. It cannot exceed 255. + // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify + // the number of SYN retransmits that TCP should send before aborting + // the attempt to connect. It cannot exceed 255. // // NOTE: This option is currently only stubbed out and is no-op. TCPSynCountOption - // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size - // of the advertised window to this value. + // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound + // the size of the advertised window to this value. 
// // NOTE: This option is currently only stubed out and is a no-op TCPWindowClampOption ) +const ( + // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use + // per-route settings. + PMTUDiscoveryWant int = iota + + // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable + // path MTU discovery. + PMTUDiscoveryDont + + // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do + // path MTU discovery. + PMTUDiscoveryDo + + // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF + // but ignore path MTU. + PMTUDiscoveryProbe +) + // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} @@ -740,7 +851,7 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the @@ -813,7 +924,11 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// IPPacketInfo is the message struture for IP_PKTINFO. +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int + +// IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable type IPPacketInfo struct { @@ -1198,6 +1313,12 @@ type UDPStats struct { // PacketSendErrors is the number of datagrams failed to be sent. PacketSendErrors *StatCounter + + // ChecksumErrors is the number of datagrams dropped due to bad checksums. + ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. @@ -1241,6 +1362,9 @@ type ReceiveErrors struct { // ClosedReceiver is the number of received packets dropped because // of receiving endpoint state being closed. ClosedReceiver StatCounter + + // ChecksumErrors is the number of packets dropped due to bad checksums. + ChecksumErrors StatCounter } // SendErrors collects packet send errors within the transport layer for diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 7f172f978..f32d58091 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -20,7 +20,7 @@ package tcpip import ( - _ "time" // Used with go:linkname. + "time" // Used with go:linkname. _ "unsafe" // Required for go:linkname. ) @@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 { _, _, mono := now() return mono } + +// AfterFunc implements Clock.AfterFunc. +func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. 
+func NewStdTimer(t *time.Timer) Timer {
+	return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 59f3b391f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -15,54 +15,54 @@
package tcpip

import (
-	"sync"
	"time"
+
+	"gvisor.dev/gvisor/pkg/sync"
)

-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following scenario where timer instances share a common
// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
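To make the scenario above concrete, here is a minimal, self-contained sketch of the per-instance earlyReturn pattern. It is illustrative only, not gVisor code; the job, schedule, and cancel names are assumptions of this sketch. Every schedule call captures a fresh flag, so an instance that fires late can only consult its own flag, never another instance's:

	package main

	import (
		"fmt"
		"sync"
		"time"
	)

	type job struct {
		mu          sync.Mutex
		earlyReturn *bool // owned by the most recently scheduled instance
	}

	// schedule arms a new instance with its own fresh flag. j.mu must be held.
	func (j *job) schedule(d time.Duration, f func()) {
		flag := false
		j.earlyReturn = &flag // older instances keep their own, now-stale flags
		time.AfterFunc(d, func() {
			j.mu.Lock()
			defer j.mu.Unlock()
			if flag { // only this instance's flag is ever consulted
				return
			}
			f()
		})
	}

	// cancel marks the active instance for early return. j.mu must be held.
	func (j *job) cancel() {
		if j.earlyReturn != nil {
			*j.earlyReturn = true
			j.earlyReturn = nil
		}
	}

	func main() {
		var j job
		j.mu.Lock()
		j.schedule(time.Millisecond, func() { fmt.Println("A") }) // instance A
		j.cancel()                                                // A will early return
		j.schedule(time.Millisecond, func() { fmt.Println("B") }) // instance B, fresh flag
		j.mu.Unlock()
		time.Sleep(10 * time.Millisecond) // prints only "B"
	}

This is the same shape the jobInstance and Job.Schedule changes below adopt: the early-return flag travels with the timer instance rather than with the Job itself.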
-type cancellableTimerInstance struct {
-	timer *time.Timer
+type jobInstance struct {
+	timer Timer

	// Used to inform the timer to early return when it gets stopped while the
	// lock the timer tries to obtain when fired is held (T1 is a goroutine that
	// tries to cancel the timer and T2 is the goroutine that handles the timer
	// firing):
-	// T1: Obtain the lock, then call Cancel()
+	// T1: Obtain the lock, then call Cancel()
	// T2: timer fires, and gets blocked on obtaining the lock
	// T1: Releases lock
	// T2: Obtains lock, does unintended work
@@ -73,27 +73,33 @@ type cancellableTimerInstance struct {
	earlyReturn *bool
}

-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
-	if t.timer != nil {
-		t.timer.Stop()
-		*t.earlyReturn = true
+func (j *jobInstance) stop() {
+	if j.timer != nil {
+		j.timer.Stop()
+		*j.earlyReturn = true
	}
}

-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
+	_ sync.NoCopy
+
+	// The clock used to schedule the backing timer.
+	clock Clock
+
	// The active instance of a cancellable timer.
-	instance cancellableTimerInstance
+	instance jobInstance

	// locker is the lock taken by the timer immediately after it fires and must
	// be held when attempting to stop the timer.
@@ -110,75 +116,91 @@ type CancellableTimer struct {
	fn func()
}

-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, j will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
-	t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+	j.instance.stop()

	// Nothing to do with the stopped instance anymore.
-	t.instance = cancellableTimerInstance{}
+	j.instance = jobInstance{}
}

-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
	// Create a new instance.
	earlyReturn := false

	// Capture the locker so that updating the timer does not cause a data race
	// when a timer fires and tries to obtain the lock (read the timer's locker).
-	locker := t.locker
-	t.instance = cancellableTimerInstance{
-		timer: time.AfterFunc(d, func() {
+	locker := j.locker
+	j.instance = jobInstance{
+		timer: j.clock.AfterFunc(d, func() {
			locker.Lock()
			defer locker.Unlock()

			if earlyReturn {
				// If we reach this point, it means that the timer fired while another
-				// goroutine called StopLocked while it had the lock. Simply return
-				// here and do nothing further.
+				// goroutine called Cancel while it had the lock. Simply return here
+				// and do nothing further.
				earlyReturn = false
				return
			}

-			t.fn()
+			j.fn()
		}),
		earlyReturn: &earlyReturn,
	}
}

-// Lock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Lock() {}
-
-// Unlock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Unlock() {}
-
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
-	return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// goroutine. l will be locked before calling f, then unlocked after f returns.
+//
+//	var clock tcpip.StdClock
+//	var mu sync.Mutex
+//	message := "foo"
+//	job := tcpip.NewJob(&clock, &mu, func() {
+//		fmt.Println(message)
+//	})
+//	job.Schedule(time.Second)
+//
+//	mu.Lock()
+//	message = "bar"
+//	mu.Unlock()
+//
+//	// Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// job.Cancel() +// mu.Unlock() +func NewJob(c Clock, l sync.Locker, f func()) *Job { + return &Job{ + clock: c, + locker: l, + fn: f, + } } diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index b4940e397..a82384c49 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -28,8 +28,8 @@ const ( longDuration = 1 * time.Second ) -func TestCancellableTimerReassignment(t *testing.T) { - var timer tcpip.CancellableTimer +func TestJobReschedule(t *testing.T) { + var clock tcpip.StdClock var wg sync.WaitGroup var lock sync.Mutex @@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - timer = *tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { wg.Done() }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) lock.Unlock() }() } wg.Wait() } -func TestCancellableTimerFire(t *testing.T) { +func TestJobExecution(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(middleDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(middleDuration) lock.Lock() - timer.StopLocked() + job.Cancel() lock.Unlock() - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { } } -func TestCancellableTimerResetFromShortDuration(t *testing.T) { +func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() // Wait for timer to fire if it wasn't correctly stopped. @@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { case <-time.After(middleDuration): } - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. 
select { @@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { } } -func TestCancellableTimerImmediatelyStop(t *testing.T) { +func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() } @@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) { } } -func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { +func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() for i := 0; i < 10; i++ { - timer.Reset(middleDuration) + job.Schedule(middleDuration) lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration * 2) - timer.StopLocked() + job.Cancel() lock.Unlock() } @@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration) - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() @@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { } } -func TestManyCancellableTimerResetUnderLock(t *testing.T) { +func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..4612be4e7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } @@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. 
switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index baf08eda6..0e46e6355 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,6 +25,8 @@ package packet import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -43,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -71,11 +76,17 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb sndBufSize: 32 * 1024, } + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { return nil, err } @@ -132,8 +154,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. 
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -158,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(b/129292371): Implement. return 0, nil, tcpip.ErrInvalidOptionValue @@ -215,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound { - return tcpip.ErrAlreadyBound + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil } // Unregister endpoint with all the nics. ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { @@ -228,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } ep.bound = true + ep.boundNIC = addr.NIC return nil } @@ -264,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. @@ -274,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (ep *endpoint) takeLastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return ep.takeLastError() + } return tcpip.ErrNotSupported } @@ -289,7 +381,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // HandlePacket implements stack.PacketEndpoint.HandlePacket. @@ -323,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header from the Header. + pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) + combinedVV := pkt.Header.View().ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var combinedVV buffer.VectorisedView + if pkt.PktType != tcpip.PacketOutgoing { + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + combinedVV = linkHeader.ToVectorisedView() + } + if pkt.PktType == tcpip.PacketOutgoing { + // For outgoing packets the Link, Network and Transport + // headers are in the pkt.Header fields normally unless + // a Raw socket is in use in which case pkt.Header could + // be nil. 
+			combinedVV.AppendView(pkt.Header.View())
		}
-		combinedVV := linkHeader.ToVectorisedView()
		combinedVV.Append(pkt.Data)
		packet.data = combinedVV
	}
-	packet.timestampNS = ep.stack.NowNanoseconds()
+	packet.timestampNS = ep.stack.Clock().NowNanoseconds()

	ep.rcvList.PushBack(&packet)
	ep.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet

import (
+	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
		panic(*err)
	}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+	if ep.lastError == nil {
+		return ""
+	}
+
+	return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+	if s == "" {
+		return
+	}
+
+	ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index a406d815e..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,8 @@
package raw

import (
+	"fmt"
+
	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -61,21 +63,23 @@ type endpoint struct {
	stack *stack.Stack `state:"manual"`
	waiterQueue *waiter.Queue
	associated bool
+	hdrIncluded bool

	// The following fields are used to manage the receive queue and are
	// protected by rcvMu.
	rcvMu sync.Mutex `state:"nosave"`
	rcvList rawPacketList
-	rcvBufSizeMax int `state:".(int)"`
	rcvBufSize int
+	rcvBufSizeMax int `state:".(int)"`
	rcvClosed bool

	// The following fields are protected by mu.
-	mu sync.RWMutex `state:"nosave"`
-	sndBufSize int
-	closed bool
-	connected bool
-	bound bool
+	mu sync.RWMutex `state:"nosave"`
+	sndBufSize int
+	sndBufSizeMax int
+	closed bool
+	connected bool
+	bound bool
	// route is the route to a remote network endpoint. It is set via
	// Connect(), and is valid only when connected is true.
	route stack.Route `state:"manual"`
@@ -91,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}

func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
-	if netProto != header.IPv4ProtocolNumber {
+	if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
		return nil, tcpip.ErrUnknownProtocol
	}

@@ -103,8 +107,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
		},
		waiterQueue: waiterQueue,
		rcvBufSizeMax: 32 * 1024,
-		sndBufSize: 32 * 1024,
+		sndBufSizeMax: 32 * 1024,
		associated: associated,
+		hdrIncluded: !associated,
+	}
+
+	// Override with stack defaults.
+	var ss stack.SendBufferSizeOption
+	if err := s.Option(&ss); err == nil {
+		e.sndBufSizeMax = ss.Default
+	}
+
+	var rs stack.ReceiveBufferSizeOption
+	if err := s.Option(&rs); err == nil {
+		e.rcvBufSizeMax = rs.Default
	}

	// Unassociated endpoints are write-only and users call Write() with IP
@@ -168,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {

// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -201,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // We can create, but not write to, unassociated IPv6 endpoints. + if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + return 0, nil, tcpip.ErrInvalidOptionValue + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -244,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -305,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrNoRoute } - // We don't support IPv6 yet, so this has to be an IPv4 address. - if len(opts.To.Addr) != header.IPv4AddressSize { - e.mu.RUnlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) @@ -340,17 +351,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch e.NetProto { - case header.IPv4ProtocolNumber: - if !e.associated { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ - Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { - return 0, nil, err - } - break + if e.hdrIncluded { + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { + return 0, nil, err } - + } else { hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, @@ -359,9 +366,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, }); err != nil { return 0, nil, err } - - default: - return 0, nil, tcpip.ErrUnknownProtocol } return int64(len(payloadBytes)), nil, nil @@ -386,11 +390,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // We don't support IPv6 yet. - if len(addr.Addr) != header.IPv4AddressSize { - return tcpip.ErrInvalidEndpointState - } - nic := addr.NIC if e.bound { if e.BindNICID == 0 { @@ -456,14 +455,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // Callers must provide an IPv4 address or no network address (for - // binding to a NIC, but not an address). - if len(addr.Addr) != 0 && len(addr.Addr) != 4 { - return tcpip.ErrInvalidEndpointState - } - // If a local address was specified, verify that it's valid. 
-	if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+	if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
		return tcpip.ErrBadLocalAddress
	}

@@ -513,17 +506,69 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {

// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
-	return tcpip.ErrUnknownProtocolOption
+	switch opt.(type) {
+	case tcpip.SocketDetachFilterOption:
+		return nil
+
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
}

// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+	switch opt {
+	case tcpip.IPHdrIncludedOption:
+		e.mu.Lock()
+		e.hdrIncluded = v
+		e.mu.Unlock()
+		return nil
+	}
	return tcpip.ErrUnknownProtocolOption
}

// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
-	return tcpip.ErrUnknownProtocolOption
+	switch opt {
+	case tcpip.SendBufferSizeOption:
+		// Make sure the send buffer size is within the min and max
+		// allowed.
+		var ss stack.SendBufferSizeOption
+		if err := e.stack.Option(&ss); err != nil {
+			panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+		}
+		if v > ss.Max {
+			v = ss.Max
+		}
+		if v < ss.Min {
+			v = ss.Min
+		}
+		e.mu.Lock()
+		e.sndBufSizeMax = v
+		e.mu.Unlock()
+		return nil
+
+	case tcpip.ReceiveBufferSizeOption:
+		// Make sure the receive buffer size is within the min and max
+		// allowed.
+		var rs stack.ReceiveBufferSizeOption
+		if err := e.stack.Option(&rs); err != nil {
+			panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+		}
+		if v > rs.Max {
+			v = rs.Max
+		}
+		if v < rs.Min {
+			v = rs.Min
+		}
+		e.rcvMu.Lock()
+		e.rcvBufSizeMax = v
+		e.rcvMu.Unlock()
+		return nil
+
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
}

// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -543,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
	case tcpip.KeepaliveEnabledOption:
		return false, nil

+	case tcpip.IPHdrIncludedOption:
+		e.mu.Lock()
+		v := e.hdrIncluded
+		e.mu.Unlock()
+		return v, nil
+
	default:
		return false, tcpip.ErrUnknownProtocolOption
	}
@@ -563,7 +614,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {

	case tcpip.SendBufferSizeOption:
		e.mu.Lock()
-		v := e.sndBufSize
+		v := e.sndBufSizeMax
		e.mu.Unlock()
		return v, nil

@@ -582,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
	e.rcvMu.Lock()

-	// Drop the packet if our buffer is currently full.
-	if e.rcvClosed {
+	// Drop the packet if our buffer is currently full or if this is an
+	// unassociated endpoint (i.e., an endpoint created w/ IPPROTO_RAW).
+	// Such endpoints are send-only.
+	// See: https://man7.org/linux/man-pages/man7/raw.7.html
+	//
+	//    An IPPROTO_RAW socket is send only.  If you really want to receive
+	//    all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+	//    Note that packet sockets don't reassemble IP fragments, unlike raw
+	//    sockets.
+ if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() @@ -627,16 +685,25 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { }, } - headers := append(buffer.View(nil), pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) - combinedVV := headers.ToVectorisedView() + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + var combinedVV buffer.VectorisedView + if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) + headers = append(headers, pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + } combinedVV.Append(pkt.Data) packet.data = combinedVV - packet.timestampNS = e.stack.NowNanoseconds() + packet.timestampNS = e.stack.Clock().NowNanoseconds() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index e26f01fae..18ff89ffc 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -76,7 +76,7 @@ go_library( ) go_test( - name = "tcp_test", + name = "tcp_x_test", size = "medium", srcs = [ "dual_stack_test.go", @@ -86,6 +86,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], + shard_count = 10, deps = [ ":tcp", "//pkg/sync", @@ -115,3 +116,11 @@ go_test( "//pkg/tcpip/seqnum", ], ) + +go_test( + name = "tcp_test", + size = "small", + srcs = ["timer_test.go"], + library = ":tcp", + deps = ["//pkg/sleep"], +) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index ad197e8db..6e00e5526 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -199,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. The endpoint is returned -// with n.mu held. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +// the connection parameters given by the arguments. +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -222,33 +220,12 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() - // Create sender and receiver. 
- // - // The receiver at least temporarily has a zero receive window scale, - // but the caller may change it (before starting the protocol loop). - n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) - n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize())) // Bootstrap the auto tuning algorithm. Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - // Lock the endpoint before registering to ensure that no out of - // band changes are possible due to incoming packets etc till - // the endpoint is done initializing. - n.mu.Lock() - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { - n.mu.Unlock() - n.Close() - return nil, err - } - - n.isRegistered = true - n.registeredReusePort = n.reusePort - - return n, nil + return n } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -259,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) - if err != nil { - return nil, err - } + ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + ep.mu.Lock() ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. @@ -270,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) @@ -290,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // to the newly created endpoint. l.listenEP.propagateInheritableOptionsLocked(ep) + if !ep.reserveTupleLocked() { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + + return nil, tcpip.ErrConnectionAborted + } + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } + // Register new endpoint so that packets are routed to it. + if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } + + ep.drainClosingSegmentQueue() + + return nil, err + } + + ep.isRegistered = true + // Perform the 3-way handshake. 
- h := newPassiveHandshake(ep, ep.rcv.rcvWnd, isn, irs, opts, deferAccept) + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + ep.notifyAborted() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -380,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // Precondition: e.mu and n.mu must be held. func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -536,6 +570,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } + iss := s.ackNumber - 1 + irs := s.sequenceNumber - 1 + // Since SYN cookies are in use this is potentially an ACK to a // SYN-ACK we sent but don't have a half open connection state // as cookies are being used to protect against a potential SYN @@ -546,7 +583,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // when under a potential syn flood attack. // // Validate the cookie. - data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) + data, ok := ctx.isCookieValid(s.id, iss, irs) if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() @@ -571,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{}) - if err != nil { + n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + + n.mu.Lock() + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. 
+	e.propagateInheritableOptionsLocked(n)
+
+	if !n.reserveTupleLocked() {
+		n.mu.Unlock()
+		n.Close()
+
		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
		e.stats.FailedConnectionAttempts.Increment()
		return
	}

-	// Propagate any inheritable options from the listening endpoint
-	// to the newly created endpoint.
-	e.propagateInheritableOptionsLocked(n)
+	// Register new endpoint so that packets are routed to it.
+	if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+		n.mu.Unlock()
+		n.Close()
+
+		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+		e.stats.FailedConnectionAttempts.Increment()
+		return
+	}
+
+	n.isRegistered = true

	// clear the tsOffset for the newly created
	// endpoint as the Timestamp was already
@@ -589,10 +644,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
	n.tsOffset = 0

	// Switch state to connected.
-	// We do not use transitionToStateEstablishedLocked here as there is
-	// no handshake state available when doing a SYN cookie based accept.
	n.isConnectNotified = true
-	n.setEndpointState(StateEstablished)
+	n.transitionToStateEstablishedLocked(&handshake{
+		ep: n,
+		iss: iss,
+		ackNum: irs + 1,
+		rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+		sndWnd: s.window,
+		rcvWndScale: e.rcvWndScaleForHandshake(),
+		sndWndScale: rcvdSynOptions.WS,
+		mss: rcvdSynOptions.MSS,
+	})

	// Do the delivery in a separate goroutine so
	// that we don't block the listen loop in case
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 7da93dcc4..1798510bc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
				<-h.ep.undrain
				h.ep.mu.Lock()
			}
+			if n&notifyError != 0 {
+				return h.ep.takeLastError()
+			}
		}

		// Wait for notification.
@@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error {
	// Initialize the resend timer.
	resendWaker := sleep.Waker{}
	timeOut := time.Duration(time.Second)
-	rt := time.AfterFunc(timeOut, func() {
-		resendWaker.Assert()
-	})
+	rt := time.AfterFunc(timeOut, resendWaker.Assert)
	defer rt.Stop()

	// Set up the wakers.
@@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
				<-h.ep.undrain
				h.ep.mu.Lock()
			}
+			if n&notifyError != 0 {
+				return h.ep.takeLastError()
+			}

		case wakerForNewSegment:
			if err := h.processSegments(); err != nil {
@@ -995,24 +999,22 @@ func (e *endpoint) completeWorkerLocked() {

// transitionToStateEstablishedLocked transitions a given endpoint
// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver if required.
+// It also initializes sender/receiver.
func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
-	if e.snd == nil {
-		// Transfer handshake state to TCP connection. We disable
-		// receive window scaling if the peer doesn't support it
-		// (indicated by a negative send window scale).
-		e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-	}
-	if e.rcv == nil {
-		rcvBufSize := seqnum.Size(e.receiveBufferSize())
-		e.rcvListMu.Lock()
-		e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
-		// Bootstrap the auto tuning algorithm. Starting at zero will
-		// result in a really large receive window after the first auto
-		// tuning adjustment.
- e.rcvAutoParams.prevCopied = int(h.rcvWnd) - e.rcvListMu.Unlock() - } + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + e.setEndpointState(StateEstablished) } @@ -1052,8 +1054,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { panic("current endpoint not removed from demuxer, enqueing segments to itself") } - if ep.(*endpoint).enqueueSegment(s) { - ep.(*endpoint).newSegmentWaker.Assert() + if ep := ep.(*endpoint); ep.enqueueSegment(s) { + ep.newSegmentWaker.Assert() } } @@ -1122,7 +1124,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - if e.EndpointState() == StateClose || e.EndpointState() == StateError { + if e.EndpointState().closed() { return nil } s := e.segmentQueue.dequeue() @@ -1442,9 +1444,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { - closeWaker.Assert() - }) + closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } } @@ -1462,7 +1462,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return err } } - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + if !e.EndpointState().closed() { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1518,6 +1518,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. cleanupOnError := func(err *tcpip.Error) { + e.stack.Stats().TCP.CurrentConnected.Decrement() e.workerCleanup = true if err != nil { e.resetConnectionLocked(err) @@ -1527,7 +1528,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } loop: - for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { + for { + switch e.EndpointState() { + case StateTimeWait, StateClose, StateError: + break loop + } + e.mu.Unlock() v, _ := s.Fetch(true) e.mu.Lock() @@ -1570,11 +1576,14 @@ loop: reuseTW = e.doTimeWait() } - // Mark endpoint as closed. - if e.EndpointState() != StateError { - e.transitionToStateCloseLocked() + // Handle any StateError transition from StateTimeWait. + if e.EndpointState() == StateError { + cleanupOnError(nil) + return nil } + e.transitionToStateCloseLocked() + // Lock released below. 
epilogue() diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 047704c80..98aecab9e 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -15,6 +15,8 @@ package tcp import ( + "encoding/binary" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -66,89 +68,68 @@ func (q *epQueue) empty() bool { // processor is responsible for processing packets queued to a tcp endpoint. type processor struct { epQ epQueue + sleeper sleep.Sleeper newEndpointWaker sleep.Waker closeWaker sleep.Waker - id int - wg sync.WaitGroup -} - -func newProcessor(id int) *processor { - p := &processor{ - id: id, - } - p.wg.Add(1) - go p.handleSegments() - return p } func (p *processor) close() { p.closeWaker.Assert() } -func (p *processor) wait() { - p.wg.Wait() -} - func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) p.newEndpointWaker.Assert() } -func (p *processor) handleSegments() { - const newEndpointWaker = 1 - const closeWaker = 2 - s := sleep.Sleeper{} - s.AddWaker(&p.newEndpointWaker, newEndpointWaker) - s.AddWaker(&p.closeWaker, closeWaker) - defer s.Done() +const ( + newEndpointWaker = 1 + closeWaker = 2 +) + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + for { - id, ok := s.Fetch(true) - if ok && id == closeWaker { - p.wg.Done() - return + if id, _ := p.sleeper.Fetch(true); id == closeWaker { + break } - for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + for { + ep := p.epQ.dequeue() + if ep == nil { + break + } if ep.segmentQueue.empty() { continue } - // If socket has transitioned out of connected state - // then just let the worker handle the packet. + // If socket has transitioned out of connected state then just let the + // worker handle the packet. // - // NOTE: We read this outside of e.mu lock which means - // that by the time we get to handleSegments the - // endpoint may not be in ESTABLISHED. But this should - // be fine as all normal shutdown states are handled by - // handleSegments and if the endpoint moves to a - // CLOSED/ERROR state then handleSegments is a noop. - if ep.EndpointState() != StateEstablished { - ep.newSegmentWaker.Assert() - continue - } - - if !ep.mu.TryLock() { - ep.newSegmentWaker.Assert() - continue - } - // If the endpoint is in a connected state then we do - // direct delivery to ensure low latency and avoid - // scheduler interactions. - if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { - // Send any active resets if required. - if err != nil { + // NOTE: We read this outside of e.mu lock which means that by the time + // we get to handleSegments the endpoint may not be in ESTABLISHED. But + // this should be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a CLOSED/ERROR state + // then handleSegments is a noop. + if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { + // If the endpoint is in a connected state then we do direct delivery + // to ensure low latency and avoid scheduler interactions. + switch err := ep.handleSegments(true /* fastPath */); { + case err != nil: + // Send any active resets if required. 
ep.resetConnectionLocked(err) + fallthrough + case ep.EndpointState() == StateClose: + ep.notifyProtocolGoroutine(notifyTickleWorker) + case !ep.segmentQueue.empty(): + p.epQ.enqueue(ep) } - ep.notifyProtocolGoroutine(notifyTickleWorker) ep.mu.Unlock() - continue - } - - if !ep.segmentQueue.empty() { - p.epQ.enqueue(ep) + } else { + ep.newSegmentWaker.Assert() } - - ep.mu.Unlock() } } } @@ -159,31 +140,36 @@ func (p *processor) handleSegments() { // hash of the endpoint id to ensure that delivery for the same endpoint happens // in-order. type dispatcher struct { - processors []*processor + processors []processor seed uint32 -} - -func newDispatcher(nProcessors int) *dispatcher { - processors := []*processor{} - for i := 0; i < nProcessors; i++ { - processors = append(processors, newProcessor(i)) - } - return &dispatcher{ - processors: processors, - seed: generateRandUint32(), + wg sync.WaitGroup +} + +func (d *dispatcher) init(nProcessors int) { + d.close() + d.wait() + d.processors = make([]processor, nProcessors) + d.seed = generateRandUint32() + for i := range d.processors { + p := &d.processors[i] + p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) + p.sleeper.AddWaker(&p.closeWaker, closeWaker) + d.wg.Add(1) + // NB: sleeper-waker registration must happen synchronously to avoid races + // with `close`. It's possible to pull all this logic into `start`, but + // that results in a heap-allocated function literal. + go p.start(&d.wg) } } func (d *dispatcher) close() { - for _, p := range d.processors { - p.close() + for i := range d.processors { + d.processors[i].close() } } func (d *dispatcher) wait() { - for _, p := range d.processors { - p.wait() - } + d.wg.Wait() } func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -231,20 +217,18 @@ func generateRandUint32() uint32 { if _, err := rand.Read(b); err != nil { panic(err) } - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return binary.LittleEndian.Uint32(b) } func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { - payload := []byte{ - byte(id.LocalPort), - byte(id.LocalPort >> 8), - byte(id.RemotePort), - byte(id.RemotePort >> 8)} + var payload [4]byte + binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) + binary.LittleEndian.PutUint16(payload[2:], id.RemotePort) h := jenkins.Sum32(d.seed) - h.Write(payload) + h.Write(payload[:]) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) - return d.processors[h.Sum32()%uint32(len(d.processors))] + return &d.processors[h.Sum32()%uint32(len(d.processors))] } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6e4d607da..0f7487963 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -396,7 +396,8 @@ type endpoint struct { mu sync.Mutex `state:"nosave"` ownedByUser uint32 - // state must be read/set using the EndpointState()/setEndpointState() methods. + // state must be read/set using the EndpointState()/setEndpointState() + // methods. 
state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -405,8 +406,8 @@ type endpoint struct { origEndpointState EndpointState `state:"nosave"` isPortReserved bool `state:"manual"` - isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + isRegistered bool `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` ttl uint8 v6only bool @@ -415,10 +416,14 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // portFlags stores the current values of port related flags. + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint // (whichever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags + boundDest tcpip.FullAddress // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -426,7 +431,7 @@ type endpoint struct { // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + effectiveNetProtos []tcpip.NetworkProtocolNumber // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -462,13 +467,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // reusePort is set to true if SO_REUSEPORT is enabled. - reusePort bool - - // registeredReusePort is set if the current endpoint registration was - // done with SO_REUSEPORT enabled. - registeredReusePort bool - // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -488,7 +486,6 @@ type endpoint struct { // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - reuseAddr bool // slowAck holds the negated state of quick ack. It is stubbed out and // does nothing. @@ -838,7 +835,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSize: DefaultReceiveBufferSize, sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), - reuseAddr: true, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -1025,15 +1021,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} } // Mark endpoint as closed.
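The endpoint hunks above replace the ad-hoc reusePort/registeredReusePort booleans with the boundPortFlags and boundDest recorded at reservation time, so StartTransportEndpointCleanup and ReleasePort are always called with exactly the tuple that was reserved. A minimal sketch of that symmetric reserve/release bookkeeping; the names and types here are illustrative stand-ins, not the netstack ports package API:

package main

import "fmt"

// reservation captures everything needed to undo a port reservation:
// the same tuple must be presented on release. The fields mirror the
// (address, port, flags, destination) bookkeeping in the hunk above,
// but the types and the portManager itself are illustrative.
type reservation struct {
	addr  string
	port  uint16
	flags uint8  // stand-in for ports.Flags
	dest  string // stand-in for tcpip.FullAddress
}

type portManager struct {
	reserved map[reservation]bool
}

func newPortManager() *portManager {
	return &portManager{reserved: make(map[reservation]bool)}
}

func (pm *portManager) reserve(r reservation) bool {
	if pm.reserved[r] {
		return false // already held under the same tuple
	}
	pm.reserved[r] = true
	return true
}

// release must be called with the same tuple that reserve recorded;
// releasing with different flags or destination would leak the entry.
func (pm *portManager) release(r reservation) {
	delete(pm.reserved, r)
}

func main() {
	pm := newPortManager()
	r := reservation{addr: "192.0.2.1", port: 8080}
	fmt.Println("reserved:", pm.reserve(r))      // true
	fmt.Println("again:", pm.reserve(r))         // false
	pm.release(r)
	fmt.Println("after release:", pm.reserve(r)) // true
}

Keying the map on the full tuple makes an asymmetric release (wrong flags or destination) visible as a leaked entry, which is the failure mode the diff is guarding against.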
@@ -1091,17 +1087,17 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false } e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1213,9 +1209,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() + defer e.UnlockUser() + + // When in SYN-SENT state, let the caller block on the receive. + // An application can initiate a non-blocking connect and then block + // on a receive. It can expect to read any data after the handshake + // is complete. RFC793, section 3.9, p58. + if e.EndpointState() == StateSynSent { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the @@ -1225,7 +1239,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.UnlockUser() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } @@ -1235,7 +1248,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.UnlockUser() if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() @@ -1522,12 +1534,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { case tcpip.ReuseAddressOption: e.LockUser() - e.reuseAddr = v + e.portFlags.TupleOnly = v e.UnlockUser() case tcpip.ReusePortOption: e.LockUser() - e.reusePort = v + e.portFlags.LoadBalanced = v e.UnlockUser() case tcpip.V6OnlyOption: @@ -1585,6 +1597,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. 
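The MTUDiscoverOption hunks above admit only the path-MTU-discovery-off setting: SetSockOptInt rejects every other value with ErrNotSupported, and GetSockOptInt always reports PMTUDiscoveryDont. A sketch of that accept-only-the-default option pattern, with stand-in names for the real tcpip constants:

package main

import (
	"errors"
	"fmt"
)

// pmtuDiscoveryDont is a stand-in for tcpip.PMTUDiscoveryDont; the
// real constant lives in pkg/tcpip. The pattern matches the hunk
// above: reject every setting except the single supported value, and
// always report that value on reads.
const pmtuDiscoveryDont = 0

var errNotSupported = errors.New("operation not supported")

type endpoint struct{}

func (e *endpoint) setMTUDiscover(v int) error {
	if v != pmtuDiscoveryDont {
		return errNotSupported
	}
	return nil // accepted, but a no-op
}

func (e *endpoint) getMTUDiscover() int {
	// Always the only supported setting.
	return pmtuDiscoveryDont
}

func main() {
	var e endpoint
	fmt.Println(e.setMTUDiscover(pmtuDiscoveryDont)) // <nil>
	fmt.Println(e.setMTUDiscover(1))                 // operation not supported
	fmt.Println(e.getMTUDiscover())                  // 0
}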
@@ -1781,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } @@ -1831,14 +1853,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.ReuseAddressOption: e.LockUser() - v := e.reuseAddr + v := e.portFlags.TupleOnly e.UnlockUser() return v, nil case tcpip.ReusePortOption: e.LockUser() - v := e.reusePort + v := e.portFlags.LoadBalanced e.UnlockUser() return v, nil @@ -1855,6 +1877,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil + case tcpip.MulticastLoopOption: + return true, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -1889,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := header.TCPDefaultMSS return v, nil + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1922,6 +1952,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.UnlockUser() return v, nil + case tcpip.MulticastTTLOption: + return 1, nil + default: return -1, tcpip.ErrUnknownProtocolOption } @@ -1931,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err + return e.takeLastError() case *tcpip.BindToDeviceOption: e.LockUser() @@ -2085,8 +2114,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } defer r.Release() - origID := e.ID - netProtos := []tcpip.NetworkProtocolNumber{netProto} e.ID.LocalAddress = r.LocalAddress e.ID.RemoteAddress = r.RemoteAddress @@ -2094,11 +2121,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } - e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2122,40 +2148,33 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if sameAddr && p == e.ID.RemotePort { return false, nil } - // reusePort is false below because connect cannot reuse a port even if - // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { - case nil: - // Port picking successful. Save the details of - // the selected port. 
- e.ID = id - e.boundBindToDevice = e.bindToDevice - e.registeredReusePort = e.reusePort - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } return false, err } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil }); err != nil { return err } } - // Remove the port reservation. This can happen when Bind is called - // before Connect: in such a case we don't want to hold on to - // reservations anymore. - if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) - e.isPortReserved = false - } - e.isRegistered = true e.setEndpointState(StateConnecting) e.route = r.Clone() @@ -2334,13 +2353,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) @@ -2427,16 +2445,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - flags := ports.Flags{ - LoadBalanced: e.reusePort, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return err } e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = flags + e.boundPortFlags = e.portFlags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port @@ -2444,7 +2459,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Any failures beyond this point must remove the port registration. defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 @@ -2467,6 +2482,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } + if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + return err + } + // Mark endpoint as bound. 
e.setEndpointState(StateBound) @@ -2531,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index cbb779666..abf1ac5c9 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() { if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - - if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { - panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") - } } // saveAcceptedChan is invoked by stateify. @@ -191,21 +187,28 @@ func (e *endpoint) Resume(s *stack.Stack) { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - if len(e.BindAddr) == 0 { - e.BindAddr = e.ID.LocalAddress + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + if err != nil { + panic("unable to parse BindAddr: " + err.String()) } - addr := e.BindAddr - port := e.ID.LocalPort - if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { - panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) + if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } + e.isPortReserved = true + + // Mark endpoint as bound. 
+ e.setEndpointState(StateBound) } switch { @@ -277,17 +280,7 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() case epState == StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.setEndpointState(StateClose) - tcpip.AsyncLoading.Done() - }() - } + e.isPortReserved = false e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 73b8a6782..b34e47bbd 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -71,34 +71,36 @@ const ( DefaultSynRetries = 6 ) -// SACKEnabled option can be used to enable SACK support in the TCP -// protocol. See: https://tools.ietf.org/html/rfc2018. +const ( + ccReno = "reno" + ccCubic = "cubic" +) + +// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool -// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. +// DelayEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. type DelayEnabled bool -// SendBufferSizeOption allows the default, min and max send buffer sizes for -// TCP endpoints to be queried or configured. +// SendBufferSizeOption is used by stack.(*Stack).TransportProtocolOption +// to get/set the default, min and max TCP send buffer sizes. type SendBufferSizeOption struct { Min int Default int Max int } -// ReceiveBufferSizeOption allows the default, min and max receive buffer size -// for TCP endpoints to be queried or configured. +// ReceiveBufferSizeOption is used by +// stack.(*Stack).TransportProtocolOption to get/set the default, min and max +// TCP receive buffer sizes. type ReceiveBufferSizeOption struct { Min int Default int Max int } -const ( - ccReno = "reno" - ccCubic = "cubic" -) - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -172,7 +174,7 @@ type protocol struct { maxRetries uint32 synRcvdCount synRcvdCounter synRetries uint8 - dispatcher *dispatcher + dispatcher dispatcher } // Number returns the tcp protocol number. @@ -513,18 +515,27 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + p := protocol{ + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultSendBufferSize, + Max: MaxBufferSize, + }, + recvBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultReceiveBufferSize, + Max: MaxBufferSize, + }, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, - dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, } + p.dispatcher.init(runtime.GOMAXPROCS(0)) + return &p } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index acacb42e4..5862c32f2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -833,25 +833,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se panic("Netstack queues FIN segments without data.") } - segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - // If the entire segment cannot be accomodated in the receiver - // advertized window, skip splitting and sending of the segment. - // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test() - // - // Linux checks this for all segment transmits not triggered - // by a probe timer. On this condition, it defers the segment - // split and transmit to a short probe timer. - // ref: include/net/tcp.h::tcp_check_probe_timer() - // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() - // - // Instead of defining a new transmit timer, we attempt to split the - // segment right here if there are no pending segments. - // If there are pending segments, segment transmits are deferred - // to the retransmit timer handler. - if s.sndUna != s.sndNxt && !segEnd.LessThan(end) { - return false - } - if !seg.sequenceNumber.LessThan(end) { return false } @@ -861,14 +842,48 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se return false } - // The segment size limit is computed as a function of sender congestion - // window and MSS. When sender congestion window is > 1, this limit can - // be larger than MSS. Ensure that the currently available send space - // is not greater than minimum of this limit and MSS. + // If the whole segment or at least 1MSS sized segment cannot + // be accommodated in the receiver advertised window, skip + // splitting and sending of the segment. ref: + // net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered by + // a probe timer. On this condition, it defers the segment split + // and transmit to a short probe timer. + // + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split + // the segment right here if there are no pending segments. If + // there are pending segments, segment transmits are deferred to + // the retransmit timer handler. + if s.sndUna != s.sndNxt { + switch { + case available >= seg.data.Size(): + // OK to send, the whole segment fits in the + // receiver's advertised window.
+ case available >= s.maxPayloadSize: + // OK to send, at least 1 MSS sized segment fits + // in the receiver's advertised window. + default: + return false + } + } + + // The segment size limit is computed as a function of sender + // congestion window and MSS. When sender congestion window is > + // 1, this limit can be larger than MSS. Ensure that the + // currently available send space is not greater than minimum of + // this limit and MSS. if available > limit { available = limit } - if available > s.maxPayloadSize { + + // If GSO is not in use then cap available to + // maxPayloadSize. When GSO is in use the gVisor GSO logic or + // the host GSO logic will cap the segment to the correct size. + if s.ep.gso == nil && available > s.maxPayloadSize { available = s.maxPayloadSize } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 5fe23113b..b9993ce1a 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -50,7 +50,7 @@ func TestFastRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -90,14 +90,14 @@ func TestFastRecovery(t *testing.T) { // Wait before checking metrics. metricPollFn := func() error { if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) } if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) } if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) } return nil } @@ -128,10 +128,10 @@ func TestFastRecovery(t *testing.T) { // Wait before checking metrics. metricPollFn = func() error { if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) } if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) } return nil } @@ -215,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } expected := tcp.InitialCwnd @@ -257,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. 
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -362,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -471,11 +471,11 @@ func TestRetransmit(t *testing.T) { // MTU size though. half := data[:len(data)/2] if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } half = data[len(data)/2:] if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. @@ -508,23 +508,23 @@ func TestRetransmit(t *testing.T) { metricPollFn := func() error { if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) } if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) } if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) } if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) } if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) } return nil diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index ace79b7b2..99521f0c1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -47,7 +47,7 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) } } @@ -400,7 +400,7 @@ func TestSACKRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Do slow start for a few iterations. 
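Most of the test churn in this and the surrounding files swaps %v for %s or %d. The stricter verbs document what kind of argument is expected and let go vet catch verb/argument mismatches; types like *tcpip.Error implement fmt.Stringer, so %s is the natural verb for them, while stat counters are plain integers and take %d. A small self-contained illustration (tcpipError is a hypothetical stand-in):

package main

import "fmt"

// tcpipError mimics a *tcpip.Error-like type that implements
// fmt.Stringer rather than the built-in error interface.
type tcpipError struct{ msg string }

func (e *tcpipError) String() string { return e.msg }

func main() {
	err := &tcpipError{msg: "connection was refused"}
	var retransmits uint64 = 4

	// %v falls back to a default format, so it hides mismatches; %s
	// and %d state the intent and let `go vet` flag an argument whose
	// type does not match the verb.
	fmt.Printf("%v\n", err)         // connection was refused (via String)
	fmt.Printf("%s\n", err)         // connection was refused
	fmt.Printf("%d\n", retransmits) // 4
}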
@@ -454,7 +454,7 @@ func TestSACKRecovery(t *testing.T) { } for _, s := range stats { if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %v, want = %v", s.name, got, want) + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) } } return nil @@ -529,19 +529,19 @@ func TestSACKRecovery(t *testing.T) { // In SACK recovery only the first segment is fast retransmitted when // entering recovery. if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) } if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) } if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) } if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) } return nil } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6ef32a1b3..e67ec42b1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -57,7 +57,7 @@ func TestGiveUpConnect(t *testing.T) { var wq waiter.Queue ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Register for notification, then start connection attempt. @@ -66,7 +66,7 @@ func TestGiveUpConnect(t *testing.T) { defer wq.EventUnregister(&waitEntry) if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } // Close the connection, wait for completion. @@ -75,21 +75,21 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) + t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted) } // Call Connect again to retrieve the handshake failure status // and stats updates. if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted) + t.Fatalf("got ep.Connect(...)
= %s, want = %s", err, tcpip.ErrAborted) } if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) } if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } } @@ -102,7 +102,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) } } @@ -115,10 +115,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) + t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) } } @@ -129,20 +129,20 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { stats := c.Stack().Stats() ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } c.EP = ep want := stats.TCP.FailedConnectionAttempts.Value() + 1 if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + t.Errorf("got c.EP.Connect(...) 
= %s, want = %s", err, tcpip.ErrNoRoute) } if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) } } @@ -156,10 +156,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) } } @@ -170,16 +170,16 @@ func TestTCPResetsSentIncrement(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } want := stats.TCP.SegmentsSent.Value() + 1 if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -213,7 +213,7 @@ func TestTCPResetsSentIncrement(t *testing.T) { metricPollFn := func() error { if got := stats.TCP.ResetsSent.Value(); got != want { - return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) + return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) } return nil } @@ -292,7 +292,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err) + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) } c.EP.Close() @@ -355,7 +355,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { }) if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) } } @@ -379,7 +379,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { }) if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) } c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) } @@ -403,7 +403,7 @@ func TestNonBlockingClose(t *testing.T) { t0 := time.Now() ep.Close() if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) + t.Fatalf("Took too long to close: %s", diff) } } @@ -415,7 +415,7 @@ func TestConnectResetAfterClose(t *testing.T) { // after 3 second in FIN_WAIT2 state. 
tcpLingerTimeout := 3 * time.Second if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -497,11 +497,11 @@ func TestCurrentConnectedIncrement(t *testing.T) { c.EP = nil if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) } gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() if gotConnected != 1 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected) + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) } ep.Close() @@ -524,10 +524,10 @@ func TestCurrentConnectedIncrement(t *testing.T) { }) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected) + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) } // Ack and send FIN as well. @@ -556,10 +556,10 @@ func TestCurrentConnectedIncrement(t *testing.T) { time.Sleep(1200 * time.Millisecond) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -575,7 +575,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { c.EP = nil if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %d, got %d", want, got) } // Send a FIN for ESTABLISHED --> CLOSED-WAIT @@ -603,7 +603,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { time.Sleep(10 * time.Millisecond) if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %d, got %d", want, got) } // Close the application endpoint for CLOSE_WAIT --> LAST_ACK @@ -620,7 +620,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ) if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Pause the endpoint`s protocolMainLoop. @@ -657,15 +657,15 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { // Expect the endpoint to be closed. 
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got) + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) } if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } // Check if the endpoint was moved to CLOSED and netstack a reset in @@ -691,7 +691,7 @@ func TestSimpleReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -714,7 +714,7 @@ func TestSimpleReceive(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -781,7 +781,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { // Start connection attempt to IPv4 address. if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("unexpected return value from Connect: %s", err) } // Receive SYN packet with our user supplied MSS. @@ -842,7 +842,7 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) { // Start connection attempt to IPv6 address. if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("unexpected return value from Connect: %s", err) } // Receive SYN packet with our user supplied MSS. @@ -1239,7 +1239,7 @@ func TestConnectBindToDevice(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -1251,7 +1251,7 @@ func TestConnectBindToDevice(t *testing.T) { ), ) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) } tcpHdr := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) @@ -1270,7 +1270,7 @@ func TestConnectBindToDevice(t *testing.T) { c.GetPacket() if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) } }) } @@ -1291,7 +1291,7 @@ func TestRstOnSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted) + t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) } // Receive SYN packet. 
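Several of the tests above wrap their assertions in a metricPollFn because TCP stats are updated asynchronously by the protocol goroutine, so a one-shot check can race with the update. A generic sketch of the poll-until-converged helper such a function is handed to; the gVisor test harness has its own version, so everything here is illustrative:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// pollUntil retries check until it returns nil or the timeout expires,
// returning the last error seen. This mirrors the metricPollFn pattern
// in the tests above.
func pollUntil(timeout, interval time.Duration, check func() error) error {
	deadline := time.Now().Add(timeout)
	var err error
	for {
		if err = check(); err == nil {
			return nil
		}
		if time.Now().After(deadline) {
			return err
		}
		time.Sleep(interval)
	}
}

func main() {
	var retransmits uint64
	go func() {
		// Simulate a stat that converges asynchronously.
		time.Sleep(50 * time.Millisecond)
		atomic.StoreUint64(&retransmits, 1)
	}()

	err := pollUntil(time.Second, 10*time.Millisecond, func() error {
		if got, want := atomic.LoadUint64(&retransmits), uint64(1); got != want {
			return fmt.Errorf("got Retransmits = %d, want = %d", got, want)
		}
		return nil
	})
	fmt.Println("poll result:", err) // <nil> once the stat converges
}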
@@ -1332,7 +1332,7 @@ func TestRstOnSynSent(t *testing.T) { } if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) } // Due to the RST the endpoint should be in an error state. @@ -1352,7 +1352,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send second half of data first, with seqnum 3 ahead of expected. @@ -1379,7 +1379,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send the first 3 bytes now. @@ -1406,7 +1406,7 @@ func TestOutOfOrderReceive(t *testing.T) { } continue } - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } read = append(read, v...) @@ -1436,7 +1436,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send 100 packets before the actual one that is expected. @@ -1513,7 +1513,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -1556,7 +1556,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // This final ACK should be ignored because an ACK on a reset doesn't mean @@ -1582,7 +1582,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -1624,7 +1624,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Cause a RST to be generated by closing the read end now since we have @@ -1643,7 +1643,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { )) // The RST puts the endpoint into an error state. 
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // The ACK to the FIN should now be rejected since the connection has been @@ -1665,19 +1665,19 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) } } @@ -1693,7 +1693,7 @@ func TestFullWindowReceive(t *testing.T) { _, _, err := c.EP.Read(nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } // Fill up the window. @@ -1728,7 +1728,7 @@ func TestFullWindowReceive(t *testing.T) { // Receive data and check it. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -1737,7 +1737,7 @@ func TestFullWindowReceive(t *testing.T) { var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) } // Check that we get an ACK for the newly non-zero window. @@ -1760,7 +1760,7 @@ func TestNoWindowShrinking(t *testing.T) { c.CreateConnected(789, 30000, 10) if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) } we, ch := waiter.NewChannelEntry(nil) @@ -1768,7 +1768,7 @@ func TestNoWindowShrinking(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send 3 bytes, check that the peer acknowledges them. @@ -1832,7 +1832,7 @@ func TestNoWindowShrinking(t *testing.T) { for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } read = append(read, v...) @@ -1866,7 +1866,7 @@ func TestSimpleSend(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received. 
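The read-side tests above all follow one shape: attempt a non-blocking Read, expect tcpip.ErrWouldBlock, register a waiter entry, then block on its notification channel until data arrives and retry. A toy version of that edge-notified read loop in plain Go; this is only the shape of the pattern, not the netstack waiter package:

package main

import (
	"errors"
	"fmt"
	"time"
)

var errWouldBlock = errors.New("operation would block")

// queue is a toy non-blocking receive queue with an edge-triggered
// notification channel.
type queue struct {
	data     chan []byte
	readable chan struct{}
}

// read never blocks: it returns errWouldBlock when nothing is queued.
func (q *queue) read() ([]byte, error) {
	select {
	case v := <-q.data:
		return v, nil
	default:
		return nil, errWouldBlock
	}
}

func main() {
	q := &queue{data: make(chan []byte, 1), readable: make(chan struct{}, 1)}

	go func() {
		time.Sleep(20 * time.Millisecond)
		q.data <- []byte{1, 2, 3}
		q.readable <- struct{}{} // notify, as the protocol goroutine would
	}()

	for {
		v, err := q.read()
		if errors.Is(err, errWouldBlock) {
			<-q.readable // block until notified, then retry the read
			continue
		}
		fmt.Println("read:", v)
		return
	}
}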
@@ -1908,7 +1908,7 @@ func TestZeroWindowSend(t *testing.T) { _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) if err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check if we got a zero-window probe. @@ -1976,7 +1976,7 @@ func TestScaledWindowConnect(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xbfff, @@ -2008,7 +2008,7 @@ func TestNonScaledWindowConnect(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xffff, @@ -2036,21 +2036,21 @@ func TestScaledWindowAccept(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -2068,7 +2068,7 @@ func TestScaledWindowAccept(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -2081,7 +2081,7 @@ func TestScaledWindowAccept(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xbfff, @@ -2109,21 +2109,21 @@ func TestNonScaledWindowAccept(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake w/ window scaling disabled.
The SYN-ACK to the SYN @@ -2142,7 +2142,7 @@ func TestNonScaledWindowAccept(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -2155,7 +2155,7 @@ func TestNonScaledWindowAccept(t *testing.T) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received, and that advertised window is 0xffff, @@ -2244,7 +2244,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { for sz < defaultMTU { v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } sz += len(v) } @@ -2311,7 +2311,7 @@ func TestSegmentMerging(t *testing.T) { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2381,7 +2381,7 @@ func TestDelay(t *testing.T) { allData = append(allData, data...) view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2428,7 +2428,7 @@ func TestUndelay(t *testing.T) { for i, data := range allData { view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2512,7 +2512,7 @@ func TestMSSNotDelayed(t *testing.T) { for i, data := range allData { view := buffer.NewViewFromBytes(data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %v", i+1, err) + t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2563,7 +2563,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { copy(view, data) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that data is received in chunks. @@ -2631,7 +2631,7 @@ func TestSetTTL(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { @@ -2639,7 +2639,7 @@ func TestSetTTL(t *testing.T) { } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %s", err) + t.Fatalf("unexpected return value from Connect: %s", err) } // Receive SYN packet. 
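The scaled/non-scaled window tests above expect an advertised window of 0xbfff when a 65535*3 byte receive buffer is negotiated with window scaling, and a saturated 0xffff without it. A worked sketch of that arithmetic (RFC 7323 window scaling); netstack's actual window bookkeeping differs in details, so treat this as illustrative:

package main

import "fmt"

// advertisedWindow sketches how a receive buffer maps to the 16-bit
// window field with and without window scaling. The shift is the
// smallest count that lets the buffer fit in 16 bits; without the
// option the field saturates at 65535.
func advertisedWindow(rcvBuf int, scalingEnabled bool) (field uint16, shift uint) {
	if !scalingEnabled {
		if rcvBuf > 65535 {
			return 65535, 0 // 0xffff: saturated
		}
		return uint16(rcvBuf), 0
	}
	for rcvBuf>>shift > 65535 {
		shift++
	}
	return uint16(rcvBuf >> shift), shift
}

func main() {
	buf := 65535 * 3
	f, s := advertisedWindow(buf, true)
	fmt.Printf("scaled:   field=%#x shift=%d\n", f, s) // field=0xbfff shift=2
	f, s = advertisedWindow(buf, false)
	fmt.Printf("unscaled: field=%#x shift=%d\n", f, s) // field=0xffff shift=0
}

With scaling, the receiver actually offers field<<shift bytes (0xbfff<<2 = 196604), which is why the tests accept the slightly-rounded-down 0xbfff rather than the full buffer size.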
@@ -2671,7 +2671,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() @@ -2683,11 +2683,11 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Do 3-way handshake. @@ -2705,7 +2705,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -2794,7 +2794,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) { select { case err := <-ch: if err != nil { - t.Fatalf("Error creating endpoint: %v", err) + t.Fatalf("Error creating endpoint: %s", err) } case <-time.After(2 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -2813,7 +2813,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Set the buffer size to a deterministic size so that we can check the @@ -2830,7 +2830,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { defer c.WQ.EventUnregister(&we) if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } // Receive SYN packet. @@ -2884,7 +2884,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { select { case <-ch: if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("GetSockOpt failed: %s", err) } case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -2899,22 +2899,22 @@ func TestCloseListener(t *testing.T) { var wq waiter.Queue ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Close the listener and measure how long it takes. t0 := time.Now() ep.Close() if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %v", diff) + t.Fatalf("Took too long to close: %s", diff) } } @@ -2950,22 +2950,25 @@ loop: case tcpip.ErrConnectionReset: break loop default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } } // Expect the state to be StateError and subsequent Reads to fail with HardError. 
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got) + t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) } if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -2990,7 +2993,7 @@ func TestSendOnResetConnection(t *testing.T) { // Try to write. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) } } @@ -3013,7 +3016,7 @@ func TestMaxRetransmitsTimeout(t *testing.T) { _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) if err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Expect first transmit and MaxRetries retransmits. @@ -3048,7 +3051,10 @@ func TestMaxRetransmitsTimeout(t *testing.T) { ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -3066,7 +3072,7 @@ func TestMaxRTO(t *testing.T) { _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) if err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } checker.IPv4(t, c.GetPacket(), checker.TCP( @@ -3089,6 +3095,63 @@ func TestMaxRTO(t *testing.T) { } } +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. 
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3097,7 +3160,7 @@ func TestFinImmediately(t *testing.T) { // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3140,7 +3203,7 @@ func TestFinRetransmit(t *testing.T) { // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3195,7 +3258,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Write something out, and have it acknowledged. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -3221,7 +3284,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Shutdown, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3268,7 +3331,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { view := buffer.NewView(10) for i := tcp.InitialCwnd; i > 0; i-- { if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } } @@ -3290,7 +3353,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // because the congestion window doesn't allow it. Wait until a // retransmit is received. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3354,7 +3417,7 @@ func TestFinWithPendingData(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -3380,7 +3443,7 @@ func TestFinWithPendingData(t *testing.T) { // Write new data, but don't acknowledge it. 
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3396,7 +3459,7 @@ func TestFinWithPendingData(t *testing.T) { // Shutdown the connection, check that we do get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3441,7 +3504,7 @@ func TestFinWithPartialAck(t *testing.T) { // FIN from the test side. view := buffer.NewView(10) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -3478,7 +3541,7 @@ func TestFinWithPartialAck(t *testing.T) { // Write new data, but don't acknowledge it. if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3494,7 +3557,7 @@ func TestFinWithPartialAck(t *testing.T) { // Shutdown the connection, check that we do get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3540,20 +3603,20 @@ func TestUpdateListenBacklog(t *testing.T) { var wq waiter.Queue ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Update the backlog with another Listen() on the same endpoint. if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %v", err) + t.Fatalf("Listen failed to update backlog: %s", err) } ep.Close() @@ -3585,7 +3648,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Send some data. Check that it's capped by the window size. view := buffer.NewView(65535) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Check that only data that fits in the scaled window is sent. @@ -3631,18 +3694,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { }) if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) } // Ensure there were no errors during handshake. If these stats have // incremented, then the connection should not have been established. 
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) } if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0) } } @@ -3666,10 +3729,10 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c.SendSegment(vv) if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) } } @@ -3770,7 +3833,7 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. @@ -3789,7 +3852,7 @@ func TestReadAfterClosedState(t *testing.T) { ) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Send some data and acknowledge the FIN. @@ -3818,7 +3881,7 @@ func TestReadAfterClosedState(t *testing.T) { time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Wait for receive to be notified. @@ -3853,11 +3916,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) 
= %s, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -3871,66 +3934,84 @@ func TestReusePort(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } c.EP.Close() c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } c.EP.Close() // Second case, an endpoint that was bound and is connecting.. c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } c.EP.Close() c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } c.EP.Close() // Third case, an endpoint that was bound and is listening. 
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } c.EP.Close() c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } } @@ -3939,11 +4020,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("GetSockOpt failed: %s", err) } if int(s) != v { - t.Fatalf("got receive buffer size = %v, want = %v", s, v) + t.Fatalf("got receive buffer size = %d, want = %d", s, v) } } @@ -3952,11 +4033,11 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("GetSockOpt failed: %s", err) } if int(s) != v { - t.Fatalf("got send buffer size = %v, want = %v", s, v) + t.Fatalf("got send buffer size = %d, want = %d", s, v) } } @@ -3969,7 +4050,7 @@ func TestDefaultBufferSizes(t *testing.T) { // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer func() { if ep != nil { @@ -3981,28 +4062,34 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } ep.Close() ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. 
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %v", err) } ep.Close() ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) @@ -4018,17 +4105,17 @@ func TestMinMaxBufferSizes(t *testing.T) { // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Set values below the min. 
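The buffer-size hunks above also convert positional composite literals to keyed form. Keyed literals stay correct if fields are added or reordered, and they make the Min/Default/Max roles readable at the call site. A small standalone sketch, where the struct mirrors the assumed shape of tcp.SendBufferSizeOption:

    package main

    import "fmt"

    // SendBufferSizeOption mirrors the option struct used in the hunks above.
    type SendBufferSizeOption struct {
        Min, Default, Max int
    }

    func main() {
        // Positional: silently wrong if the field order ever changes.
        positional := SendBufferSizeOption{1, 4 << 10, 40 << 10}
        // Keyed: self-documenting and robust to reordering.
        keyed := SendBufferSizeOption{Min: 1, Default: 4 << 10, Max: 40 << 10}
        fmt.Println(positional == keyed) // true
    }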
@@ -4065,12 +4152,12 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() if err := s.CreateNIC(321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) + t.Errorf("CreateNIC failed: %s", err) } // nicIDPtr is used instead of taking the address of NICID literals, which is @@ -4095,12 +4182,12 @@ func TestBindToDeviceOption(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) + t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr) } } bindToDevice := tcpip.BindToDeviceOption(88888) if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %v, want %v", err, nil) + t.Errorf("GetSockOpt got %s, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { t.Errorf("bindToDevice got %d, want %d", got, want) @@ -4166,12 +4253,12 @@ func TestSelfConnect(t *testing.T) { var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Register for notification, then start connection attempt. @@ -4180,12 +4267,12 @@ func TestSelfConnect(t *testing.T) { defer wq.EventUnregister(&waitEntry) if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) } <-notifyCh if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("Connect failed: %v", err) + t.Fatalf("Connect failed: %s", err) } // Write something. @@ -4193,7 +4280,7 @@ func TestSelfConnect(t *testing.T) { view := buffer.NewView(len(data)) copy(view, data) if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } // Read back what was written. 
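TestBindToDeviceOption above drives the option through a table of set/get actions; stripped of the table, the round-trip it exercises is this fragment (assuming ep, t, and NIC 321 from the test):

    opt := tcpip.BindToDeviceOption(321)
    if err := ep.SetSockOpt(opt); err != nil {
        t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
    }
    var got tcpip.BindToDeviceOption
    if err := ep.GetSockOpt(&got); err != nil {
        t.Fatalf("GetSockOpt failed: %s", err)
    }
    if got != opt {
        t.Errorf("bindToDevice got %d, want %d", got, opt)
    }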
@@ -4202,12 +4289,12 @@ func TestSelfConnect(t *testing.T) { rd, _, err := ep.Read(nil) if err != nil { if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } <-notifyCh rd, _, err = ep.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } } @@ -4291,7 +4378,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { } ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps = append(eps, ep) switch network { @@ -4342,7 +4429,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %v", i, err) + t.Fatalf("Bind(%d) failed: %s", i, err) } } want := tcpip.ErrConnectStarted @@ -4350,7 +4437,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { want = tcpip.ErrNoPortAvailable } if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) + t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) } }) } @@ -4384,7 +4471,7 @@ func TestPathMTUDiscovery(t *testing.T) { } if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { @@ -4487,7 +4574,7 @@ func TestStackSetCongestionControl(t *testing.T) { var oldCC tcpip.CongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) + t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { @@ -4574,12 +4661,12 @@ func TestEndpointSetCongestionControl(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } var oldCC tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) + t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err) } if connected { @@ -4587,12 +4674,12 @@ func TestEndpointSetCongestionControl(t *testing.T) { } if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) + t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) + t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err) } got, want := cc, oldCC @@ -4615,7 +4702,7 @@ func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) } } @@ -4657,14 +4744,14 @@ func TestKeepalive(t *testing.T) { // Check that 
the connection is still alive. if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. view := buffer.NewView(3) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -4744,15 +4831,18 @@ func TestKeepalive(t *testing.T) { ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got) + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) } if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -4854,19 +4944,19 @@ func TestListenBacklogFull(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. // Start listening. listenBacklog := 2 if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } for i := 0; i < listenBacklog; i++ { @@ -4899,7 +4989,7 @@ func TestListenBacklogFull(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4928,7 +5018,7 @@ func TestListenBacklogFull(t *testing.T) { case <-ch: newEP, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -4942,7 +5032,7 @@ func TestListenBacklogFull(t *testing.T) { b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) + t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } } @@ -5162,19 +5252,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) { var err *tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } // Bind to wildcard. if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } // Test acceptance. // Start listening. 
listenBacklog := 1 if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send two SYN's the first one should get a SYN-ACK, the @@ -5240,7 +5330,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { case <-ch: newEP, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -5254,7 +5344,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) + t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } } @@ -5316,7 +5406,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -5450,7 +5540,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { pkt := c.GetPacket() tcpHdr = header.TCP(header.IPv4(pkt).Payload()) if string(tcpHdr.Payload()) != data { - t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) + t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } } @@ -5460,20 +5550,20 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } c.EP = ep if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } stats := c.Stack().Stats() @@ -5494,7 +5584,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -5503,7 +5593,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { } if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) } } @@ -5514,14 +5604,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { stats := c.Stack().Stats() ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } c.EP = ep if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } 
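The accept paths in these listen tests all repeat the same notification idiom; condensed, it is the following fragment (assuming a listening endpoint ep, its waiter.Queue wq, and the test's t):

    we, ch := waiter.NewChannelEntry(nil)
    wq.EventRegister(&we, waiter.EventIn)
    defer wq.EventUnregister(&we)

    nep, _, err := ep.Accept()
    if err == tcpip.ErrWouldBlock {
        select {
        case <-ch:
            // The waiter fired: the accept queue now holds a connection.
            if nep, _, err = ep.Accept(); err != nil {
                t.Fatalf("Accept failed: %s", err)
            }
        case <-time.After(1 * time.Second):
            t.Fatalf("Timed out waiting to accept")
        }
    }
    // nep is the accepted endpoint.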
srcPort := uint16(context.TestPort) @@ -5546,10 +5636,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { time.Sleep(50 * time.Millisecond) if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) } if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) } we, ch := waiter.NewChannelEntry(nil) @@ -5564,7 +5654,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { case <-ch: _, _, err = c.EP.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -5579,28 +5669,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { - t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected) + t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected) } if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { - t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1) + t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) @@ -5617,7 +5707,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { case <-ch: aep, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -5625,25 +5715,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } } if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected) } // Listening endpoint remains in listen state. 
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } ep.Close() // Give worker goroutines time to receive the close notification. time.Sleep(1 * time.Second) if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } // Accepted endpoint remains open when the listen endpoint is closed. if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } } @@ -5663,13 +5753,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Enable auto-tuning. if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -5784,13 +5874,13 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Enable auto-tuning. if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. 
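Both auto-tuning tests above share this setup: widen the permitted receive-buffer range, then turn moderation on, because buffer moderation is opt-in at the stack level. Condensed (stk and t as in the tests):

    const receiveBufferSize = 80 << 10 // 80KB.
    const maxReceiveBufferSize = receiveBufferSize * 10
    if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
        Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize,
    }); err != nil {
        t.Fatalf("SetTransportProtocolOption failed: %s", err)
    }
    // Without this, the advertised window stays pinned to the default size.
    if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
        t.Fatalf("SetTransportProtocolOption failed: %s", err)
    }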
@@ -5935,7 +6025,7 @@ func TestDelayEnabled(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) } checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) } @@ -5946,7 +6036,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del var gotDelayEnabled tcp.DelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { - t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } if gotDelayEnabled != wantDelayEnabled { t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) @@ -5954,7 +6044,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) if err != nil { - t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) } gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { @@ -6515,10 +6605,10 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want) + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) } if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } } @@ -6715,7 +6805,7 @@ func TestTCPUserTimeout(t *testing.T) { // Send some data and wait before ACKing it. view := buffer.NewView(3) if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) + t.Fatalf("Write failed: %s", err) } next := uint32(c.IRS) + 1 @@ -6765,11 +6855,14 @@ func TestTCPUserTimeout(t *testing.T) { ) if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) } if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -6796,7 +6889,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Check that the connection is still alive. if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } // Now receive 1 keepalives, but don't ACK it. 
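The new CurrentConnected assertions recur after each teardown path in this file (reset, retransmit timeout, keepalive timeout, user timeout). A helper along these lines, which is hypothetical and not part of the change, would capture the repeated pair:

    func checkConnectionGone(t *testing.T, s *stack.Stack) {
        t.Helper()
        if got := s.Stats().TCP.CurrentEstablished.Value(); got != 0 {
            t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
        }
        if got := s.Stats().TCP.CurrentConnected.Value(); got != 0 {
            t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
        }
    }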
@@ -6837,10 +6930,13 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ) if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) } if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) } } @@ -6896,11 +6992,11 @@ func TestIncreaseWindowOnReceive(t *testing.T) { // ack should be sent in response to that. The window was not // zero, but it grew to larger than MSS. if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } // After reading two packets, we surely crossed MSS. See the ack: @@ -6997,13 +7093,13 @@ func TestTCPDeferAccept(t *testing.T) { const tcpDeferAccept = 1 * time.Second if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7026,7 +7122,7 @@ func TestTCPDeferAccept(t *testing.T) { time.Sleep(50 * time.Millisecond) aep, _, err := c.EP.Accept() if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) } aep.Close() @@ -7054,13 +7150,13 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { const tcpDeferAccept = 1 * time.Second if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. 
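TestTCPDeferAccept above and TestTCPDeferAcceptTimeout below hinge on the same behavior: once TCPDeferAcceptOption is set, a completed handshake alone does not make the endpoint acceptable, so Accept keeps returning ErrWouldBlock until the peer sends data or the deferral period lapses. The shared setup, condensed (c.EP listening as in the tests):

    const tcpDeferAccept = 1 * time.Second
    if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
        t.Fatalf("SetSockOpt(TCPDeferAcceptOption(%s)) failed: %s", tcpDeferAccept, err)
    }
    // The handshake has completed, yet Accept must still block.
    if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
        t.Fatalf("got c.EP.Accept() = %s, want = %s", err, tcpip.ErrWouldBlock)
    }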
@@ -7094,7 +7190,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { time.Sleep(50 * time.Millisecond) aep, _, err := c.EP.Accept() if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) } aep.Close() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 9721f6caf..06fde2a79 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -144,12 +144,12 @@ func New(t *testing.T, mtu uint32) *Context { }) // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) } // Increase minimum RTO in tests to avoid test flakes due to early diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index c70525f27..7981d469b 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { t.timer.Stop() + *t = timer{} } // checkExpiration checks if the given timer has actually expired, it should be diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go new file mode 100644 index 000000000..dbd6dff54 --- /dev/null +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -0,0 +1,47 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sleep" +) + +func TestCleanup(t *testing.T) { + const ( + timerDurationSeconds = 2 + isAssertedTimeoutSeconds = timerDurationSeconds + 1 + ) + + tmr := timer{} + w := sleep.Waker{} + tmr.init(&w) + tmr.enable(timerDurationSeconds * time.Second) + tmr.cleanup() + + if want := (timer{}); tmr != want { + t.Errorf("got tmr = %+v, want = %+v", tmr, want) + } + + // The waker should not be asserted. 
+ for i := 0; i < isAssertedTimeoutSeconds; i++ { + time.Sleep(time.Second) + if w.IsAsserted() { + t.Fatalf("waker asserted unexpectedly") + } + } +} diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index df5efbf6a..6e692da07 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,6 +15,8 @@ package udp import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -94,6 +96,7 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int + sndBufSizeMax int state EndpointState route stack.Route `state:"manual"` dstPort uint16 @@ -106,6 +109,7 @@ type endpoint struct { portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool + noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -159,7 +163,7 @@ type multicastMembership struct { } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - return &endpoint{ + e := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -181,10 +185,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue multicastTTL: 1, multicastLoop: true, rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024, + sndBufSizeMax: 32 * 1024, state: StateInitial, uniqueID: s.UniqueID(), } + + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + e.rcvBufSizeMax = rs.Default + } + + return e } // UniqueID implements stack.TransportEndpoint.UniqueID. 
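The sendUDP and HandlePacket hunks further below both encode the checksum asymmetry between IP versions: on IPv4 the UDP checksum is optional and an all-zero field means the sender skipped it (RFC 768), while on IPv6 it is mandatory (RFC 2460 Section 8.1). The receive-side decision reduces to roughly this hedged sketch (not the actual code; the real check folds the pseudo-header sum via hdr.CalculateChecksum):

    // acceptChecksum reports whether a received UDP datagram passes the
    // rules above. folded is the one's-complement checksum over the
    // pseudo-header and segment, folded to 16 bits.
    func acceptChecksum(isIPv6 bool, hdrChecksum, folded uint16) bool {
        if hdrChecksum == 0 && !isIPv6 {
            return true // An IPv4 sender opted out of checksumming.
        }
        return folded == 0xffff // Otherwise the sum must verify.
    }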
@@ -215,7 +232,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} } @@ -513,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -537,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.multicastLoop = v e.mu.Unlock() + case tcpip.NoChecksumOption: + e.mu.Lock() + e.noChecksum = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v @@ -590,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -611,8 +640,43 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.mu.Unlock() case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) + } + + if v < rs.Min { + v = rs.Min + } + if v > rs.Max { + v = rs.Max + } + + e.mu.Lock() + e.rcvBufSizeMax = v + e.mu.Unlock() + return nil case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) + } + + if v < ss.Min { + v = ss.Min + } + if v > ss.Max { + v = ss.Max + } + e.mu.Lock() + e.sndBufSizeMax = v + e.mu.Unlock() + return nil } return nil @@ -752,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } @@ -774,6 +841,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.NoChecksumOption: + e.mu.RLock() + v := e.noChecksum + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -843,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. 
+ return tcpip.PMTUDiscoveryDont, nil + case tcpip.MulticastTTLOption: e.mu.Lock() v := int(e.multicastTTL) @@ -861,7 +938,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { case tcpip.SendBufferSizeOption: e.mu.Lock() - v := e.sndBufSize + v := e.sndBufSizeMax e.mu.Unlock() return v, nil @@ -908,7 +985,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -922,8 +999,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Length: length, }) - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) @@ -996,7 +1077,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. 
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } e.state = StateInitial @@ -1147,7 +1228,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return id, e.bindToDevice, err } @@ -1157,7 +1238,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } return id, e.bindToDevice, err @@ -1299,10 +1380,37 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - e.rcvMu.Lock() + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + + // Verify checksum unless RX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value means + // the transmitter omitted the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + if hdr.CalculateChecksum(xsum) != 0xffff { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return + } + } + e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() @@ -1343,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 4218e7d03..0e7464e3a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -32,9 +32,24 @@ import ( const ( // ProtocolNumber is the udp protocol number. 
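
Both the send path (sendUDP) and the receive path (HandlePacket) above fold the pseudo-header and payload into the UDP checksum. The arithmetic behind header.Checksum is the RFC 1071 one's-complement sum; a self-contained sketch of that fold:

```
package main

import "fmt"

// onesComplementSum folds data into a 16-bit one's-complement sum, the
// arithmetic behind header.Checksum (RFC 1071). Sketch for illustration.
func onesComplementSum(data []byte, initial uint16) uint16 {
	sum := uint32(initial)
	for i := 0; i+1 < len(data); i += 2 {
		sum += uint32(data[i])<<8 | uint32(data[i+1])
	}
	if len(data)%2 == 1 {
		// An odd trailing byte is padded with a zero byte.
		sum += uint32(data[len(data)-1]) << 8
	}
	// Fold the carries back into the low 16 bits.
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

func main() {
	// A receiver accepts a packet when folding the pseudo-header, UDP
	// header, and payload (checksum field included) yields 0xffff, which
	// is exactly the CalculateChecksum(xsum) != 0xffff test above.
	fmt.Printf("%#04x\n", onesComplementSum([]byte{0x00, 0x01, 0xff, 0xfe}, 0))
}
```
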
ProtocolNumber = header.UDPProtocolNumber + + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4KiB bytes. + + // DefaultSendBufferSize is the default size of the send buffer for + // an endpoint. + DefaultSendBufferSize = 32 << 10 // 32KiB + + // DefaultReceiveBufferSize is the default size of the receive buffer + // for an endpoint. + DefaultReceiveBufferSize = 32 << 10 // 32KiB + + // MaxBufferSize is the largest size a receive/send buffer can grow to. + MaxBufferSize = 4 << 20 // 4MiB ) -type protocol struct{} +type protocol struct { +} // Number returns the udp protocol number. func (*protocol) Number() tcpip.TransportProtocolNumber { @@ -183,12 +198,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 313a3f117..90781cf49 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. 
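
The new Min/Default/Max constants are shaped for the stack.SendBufferSizeOption and stack.ReceiveBufferSizeOption structs the endpoint consults (the ss.Min, ss.Default, ss.Max reads above). A sketch of the correspondence, assuming those structs carry exactly these three int fields:

```
// Sketch, inside package udp: the buffer limits above expressed as the
// option struct the endpoint reads. Field names follow the ss.Min /
// ss.Default / ss.Max accesses in endpoint.go.
var udpSendBufferLimits = stack.SendBufferSizeOption{
	Min:     MinBufferSize,         // 4 KiB
	Default: DefaultSendBufferSize, // 32 KiB
	Max:     MaxBufferSize,         // 4 MiB
}
```
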
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio wep = sniffer.New(ep) } if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } s.SetRouteTable([]tcpip.Route{ @@ -391,17 +409,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } else { - c.injectV6Packet(payload, &h, true /* valid */) + 
buf := c.buildV6Packet(payload, &h) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } } -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV6Packet creates a V6 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -420,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: l, + Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. @@ -439,17 +455,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV4Packet creates a V4 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -483,11 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } func newPayload() []byte { @@ -509,7 +516,7 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() @@ -643,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) 
failed: %s", err) } } @@ -654,19 +661,19 @@ func TestBindReservedPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } addr, err := c.ep.GetLocalAddress() if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) + t.Fatalf("GetLocalAddress failed: %s", err) } // We can't bind the address reserved by the connected endpoint above. { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { @@ -677,7 +684,7 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint @@ -687,7 +694,7 @@ func TestBindReservedPort(t *testing.T) { } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() @@ -697,11 +704,11 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() } @@ -714,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -729,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { // Bind to v4 mapped wildcard. if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -744,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -759,7 +766,7 @@ func TestV6ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -796,7 +803,10 @@ func TestV4ReadSelfSource(t *testing.T) { h := unicastV4.header4Tuple(incoming) h.srcAddr = h.dstAddr - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -817,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) { // Bind to wildcard. 
 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
 }
 // Test acceptance.
@@ -880,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
 }
 }
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with a multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticastStats checks that a discarded packet
+// that was sent with a multicast SOURCE address increments
+// the correct counters and that a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ t.Helper()
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
 // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
 // and receive broadcast and unicast data.
 func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -955,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
 payload := buffer.View(newPayload())
 n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
 if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
 }
 if n != int64(len(payload)) {
 c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
 }
@@ -1005,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
 // Bind to wildcard.
 if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
 }
 p := testDualWrite(c)
@@ -1022,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
 // Connect to v6 address.
 if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
 }
 testWrite(c, unicastV6)
@@ -1043,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
 // Connect to v4 mapped address.
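
The reverseMulticast flows drive the new HandlePacket guard shown earlier: a multicast group is only ever a valid destination, so a packet whose source is multicast must be dropped and counted. The guard in isolation, as a sketch using the header predicates from this change:

```
// isForbiddenSource mirrors the check in HandlePacket: multicast addresses
// must never appear as a UDP source.
func isForbiddenSource(addr tcpip.Address) bool {
	return header.IsV4MulticastAddress(addr) ||
		header.IsV6MulticastAddress(addr)
}
```
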
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV4in6) @@ -1070,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Write to v6 address. @@ -1085,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV6) @@ -1099,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV4) @@ -1234,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testRead(c, unicastV4) @@ -1259,6 +1323,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } +func TestNoChecksum(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Disable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + // This option is effective on IPv4 only. + testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) + + // Enable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) + }) + } +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1506,12 +1594,12 @@ func TestMulticastInterfaceOption(t *testing.T) { Port: stackPort, } if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } } if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOpt failed: %s", err) } // Verify multicast interface addr and NIC were set correctly. @@ -1519,7 +1607,7 @@ func TestMulticastInterfaceOption(t *testing.T) { ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) + c.t.Fatalf("GetSockOpt failed: %s", err) } if ifoptGot != ifoptWant { c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) @@ -1691,7 +1779,7 @@ func TestV6UnknownDestination(t *testing.T) { } // TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. 
+// global and endpoint stats are incremented. func TestIncrementMalformedPacketsReceived(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1699,20 +1787,27 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } payload := newPayload() - c.t.Helper() h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) + buf := c.buildV6Packet(payload, &h) - var want uint64 = 1 + // Invalidate the UDP header length field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetLength(u.Length() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) } if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) } } @@ -1728,7 +1823,6 @@ func TestShortHeader(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - c.t.Helper() h := unicastV6.header4Tuple(incoming) // Allocate a buffer for an IPv6 and too-short UDP header. @@ -1768,6 +1862,199 @@ func TestShortHeader(t *testing.T) { } } +// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. 
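
The corruption loop in TestIncrementChecksumErrorsV4 above keeps incrementing until the field is non-zero because a naive +1 can wrap 0xffff to 0x0000, and on IPv4 a zero UDP checksum means "no checksum supplied" rather than "checksum wrong". A sketch of a corruption helper that avoids the reserved value:

```
// corruptChecksum returns a checksum that is guaranteed wrong but never
// zero, since zero would disable verification on IPv4 (RFC 768).
func corruptChecksum(sum uint16) uint16 {
	sum++
	if sum == 0 { // 0xffff + 1 wrapped around
		sum++
	}
	return sum
}
```
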
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + + // Invalidate the UDP header checksum field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV4 verifies if the checksum value is zero, global and +// endpoint states are *not* incremented (UDP checksum is optional on IPv4). +func TestChecksumZeroV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. 
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 0 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV6 verifies if the checksum value is zero, global and +// endpoint states are incremented (UDP checksum is *not* optional on IPv6). +func TestChecksumZeroV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { @@ -1778,15 +2065,15 @@ func TestShutdownRead(t *testing.T) { // Bind to wildcard. 
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingRead(c, unicastV6, true /* expectReadError */) @@ -1809,11 +2096,11 @@ func TestShutdownWrite(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go index 8fed29ff5..70945f234 100644 --- a/pkg/test/criutil/criutil.go +++ b/pkg/test/criutil/criutil.go @@ -22,6 +22,9 @@ import ( "fmt" "os" "os/exec" + "path" + "regexp" + "strconv" "strings" "time" @@ -33,28 +36,44 @@ import ( type Crictl struct { logger testutil.Logger endpoint string + runpArgs []string cleanup []func() } -// resolvePath attempts to find binary paths. It may set the path to invalid, +// ResolvePath attempts to find binary paths. It may set the path to invalid, // which will cause the execution to fail with a sensible error. -func resolvePath(executable string) string { +func ResolvePath(executable string) string { + runtime, err := dockerutil.RuntimePath() + if err == nil { + // Check first the directory of the runtime itself. + if dir := path.Dir(runtime); dir != "" && dir != "." { + guess := path.Join(dir, executable) + if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 { + return guess + } + } + } + + // Try to find via the path. guess, err := exec.LookPath(executable) - if err != nil { - guess = fmt.Sprintf("/usr/local/bin/%s", executable) + if err == nil { + return guess } - return guess + + // Return a default path. + return fmt.Sprintf("/usr/local/bin/%s", executable) } // NewCrictl returns a Crictl configured with a timeout and an endpoint over // which it will talk to containerd. -func NewCrictl(logger testutil.Logger, endpoint string) *Crictl { +func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl { // Attempt to find the executable, but don't bother propagating the // error at this point. The first command executed will return with a // binary not found error. return &Crictl{ logger: logger, endpoint: endpoint, + runpArgs: runpArgs, } } @@ -67,8 +86,8 @@ func (cc *Crictl) CleanUp() { } // RunPod creates a sandbox. It corresponds to `crictl runp`. -func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { - podID, err := cc.run("runp", sbSpecFile) +func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) { + podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile) if err != nil { return "", fmt.Errorf("runp failed: %v", err) } @@ -79,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { // Create creates a container within a sandbox. It corresponds to `crictl // create`. 
 func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) {
- podID, err := cc.run("create", podID, contSpecFile, sbSpecFile)
+ // In version 1.16.0, crictl annoyingly started attempting to pull the
+ // container image, even if it was already available locally. We therefore
+ // need to parse the version and add an appropriate --no-pull argument
+ // since the image has already been loaded locally.
+ out, err := cc.run("-v")
+ if err != nil {
+ return "", err
+ }
+ r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+ vs := r.FindStringSubmatch(out)
+ if len(vs) != 4 {
+ return "", fmt.Errorf("crictl -v had unexpected output: %s", out)
+ }
+ major, err := strconv.ParseUint(vs[1], 10, 64)
 if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+ minor, err := strconv.ParseUint(vs[2], 10, 64)
+ if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+
+ args := []string{"create"}
+ if (major == 1 && minor >= 16) || major > 1 {
+ args = append(args, "--no-pull")
+ }
+ args = append(args, podID)
+ args = append(args, contSpecFile)
+ args = append(args, sbSpecFile)
+
+ podID, err = cc.run(args...)
+ if err != nil {
 return "", fmt.Errorf("create failed: %v", err)
 }
+ // Strip the trailing newline from crictl output.
 return strings.TrimSpace(podID), nil
 }
@@ -179,7 +230,7 @@ func (cc *Crictl) Import(image string) error {
 // be pushing a lot of bytes in order to import the image. The connect
 // timeout stays the same and is inherited from the Crictl instance.
 cmd := testutil.Command(cc.logger,
- resolvePath("ctr"),
+ ResolvePath("ctr"),
 fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
 fmt.Sprintf("--address=%s", cc.endpoint),
 "-n", "k8s.io", "images", "import", "-")
@@ -260,7 +311,7 @@ func (cc *Crictl) StopContainer(contID string) error {
 // StartPodAndContainer starts a sandbox and container in that sandbox. It
 // returns the pod ID and container ID.
-func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) {
 if err := cc.Import(image); err != nil {
 return "", "", err
 }
@@ -277,7 +328,7 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string,
 }
 cc.cleanup = append(cc.cleanup, cleanup)
- podID, err := cc.RunPod(sbSpecFile)
+ podID, err := cc.RunPod(runtime, sbSpecFile)
 if err != nil {
 return "", "", err
 }
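
The Create change above probes `crictl -v` once and only adds --no-pull for crictl 1.16 and later, the release that (per the comment above) started pulling on create. The version gate reduced to a helper (a sketch; the name is illustrative):

```
// needsNoPull reports whether this crictl release attempts an image pull on
// `crictl create`, requiring --no-pull for locally imported images.
func needsNoPull(major, minor uint64) bool {
	return (major == 1 && minor >= 16) || major > 1
}
```
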
@@ -307,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
 // run runs crictl with the given args.
 func (cc *Crictl) run(args ...string) (string, error) {
 defaultArgs := []string{
- resolvePath("crictl"),
+ ResolvePath("crictl"),
 "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
 "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
 }
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 7c8758e35..a5e84658a 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -1,14 +1,42 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
 package(licenses = ["notice"])
 go_library(
 name = "dockerutil",
 testonly = 1,
- srcs = ["dockerutil.go"],
+ srcs = [
+ "container.go",
+ "dockerutil.go",
+ "exec.go",
+ "network.go",
+ "profile.go",
+ ],
 visibility = ["//:sandbox"],
 deps = [
 "//pkg/test/testutil",
- "@com_github_kr_pty//:go_default_library",
+ "@com_github_docker_docker//api/types:go_default_library",
+ "@com_github_docker_docker//api/types/container:go_default_library",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ "@com_github_docker_docker//api/types/network:go_default_library",
+ "@com_github_docker_docker//client:go_default_library",
+ "@com_github_docker_docker//pkg/stdcopy:go_default_library",
+ "@com_github_docker_go_connections//nat:go_default_library",
 ],
)
+
+go_test(
+ name = "profile_test",
+ size = "large",
+ srcs = [
+ "profile_test.go",
+ ],
+ library = ":dockerutil",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ # Also requires the test to be run as root.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md
new file mode 100644
index 000000000..870292096
--- /dev/null
+++ b/pkg/test/dockerutil/README.md
@@ -0,0 +1,86 @@
+# dockerutil
+
+This package is for creating and controlling docker containers for testing
+runsc, gVisor's Docker/Kubernetes runtime binary. A simple test may look like:
+
+```
+ func TestSuperCool(t *testing.T) {
+ ctx := context.Background()
+ c := dockerutil.MakeContainer(ctx, t)
+ got, err := c.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "echo", "super cool")
+ if err != nil {
+ t.Fatalf("err was not nil: %v", err)
+ }
+ want := "super cool"
+ if !strings.Contains(got, want) {
+ t.Fatalf("want: %s, got: %s", want, got)
+ }
+ }
+```
+
+For further examples, see many of our end-to-end tests elsewhere in the repo,
+such as those in //test/e2e or benchmarks at //test/benchmarks.
+
+dockerutil uses the "official" docker golang api, which is
+[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil
+is a thin wrapper around this API, allowing new use cases to be implemented
+easily.
+
+## Profiling
+
+dockerutil is capable of generating profiles. Currently, the only option is to
+use pprof profiles generated by `runsc debug`. The profiler will generate Block,
+CPU, Heap, Goroutine, and Mutex profiles. To generate profiles:
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+ ARGS="--profile"`. Add other flags with ARGS as needed, such as
+ `--platform=kvm` or `--vfs2`.
+* Restart docker: `sudo service docker restart`
+
+To run and generate CPU profiles, run:
+
+```
+make sudo TARGETS=//path/to:target \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
--pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` + +Container name in most tests and benchmarks in gVisor is usually the test name +and some random characters like so: +`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2` + +Profiling requires root as runsc debug inspects running containers in /var/run +among other things. + +### Writing for Profiling + +The below shows an example of using profiles with dockerutil. + +``` +func TestSuperCool(t *testing.T){ + ctx := context.Background() + // profiled and using runtime from dockerutil.runtime flag + profiled := MakeContainer() + + // not profiled and using runtime runc + native := MakeNativeContainer() + + err := profiled.Spawn(ctx, RunOpts{ + Image: "some/image", + }, "sleep", "100000") + // profiling has begun here + ... + expensive setup that I don't want to profile. + ... + profiled.RestartProfiles() + // profiled activity +} +``` + +In the above example, `profiled` would be profiled and `native` would not. The +call to `RestartProfiles()` restarts the clock on profiling. This is useful if +the main activity being tested is done with `docker exec` or `container.Spawn()` +followed by one or more `container.Exec()` calls. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go new file mode 100644 index 000000000..b59503188 --- /dev/null +++ b/pkg/test/dockerutil/container.go @@ -0,0 +1,543 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net" + "os" + "path" + "regexp" + "strconv" + "strings" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/docker/go-connections/nat" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Container represents a Docker Container allowing +// user to configure and control as one would with the 'docker' +// client. Container is backed by the offical golang docker API. +// See: https://pkg.go.dev/github.com/docker/docker. +type Container struct { + Name string + runtime string + + logger testutil.Logger + client *client.Client + id string + mounts []mount.Mount + links []string + copyErr error + cleanups []func() + + // Profiles are profiles added to this container. They contain methods + // that are run after Creation, Start, and Cleanup of this Container, along + // a handle to restart the profile. Generally, tests/benchmarks using + // profiles need to run as root. + profiles []Profile + + // Stores streams attached to the container. Used by WaitForOutputSubmatch. + streams types.HijackedResponse + + // stores previously read data from the attached streams. + streamBuf bytes.Buffer +} + +// RunOpts are options for running a container. 
+type RunOpts struct {
+ // Image is the image relative to images/. This will be mangled
+ // appropriately, to ensure that only first-party images are used.
+ Image string
+
+ // Memory is the memory limit in bytes.
+ Memory int
+
+ // CpusetCpus is the set of CPUs in which to allow execution
+ // ("0", "1", "0-2").
+ CpusetCpus string
+
+ // Ports are the ports to be allocated.
+ Ports []int
+
+ // WorkDir sets the working directory.
+ WorkDir string
+
+ // ReadOnly sets the read-only flag.
+ ReadOnly bool
+
+ // Env are additional environment variables.
+ Env []string
+
+ // User is the user to use.
+ User string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // CapAdd are the extra set of capabilities to add.
+ CapAdd []string
+
+ // CapDrop are the extra set of capabilities to drop.
+ CapDrop []string
+
+ // Mounts is the list of directories/files to be mounted inside the container.
+ Mounts []mount.Mount
+
+ // Links is the list of containers to be connected to the container.
+ Links []string
+}
+
+// MakeContainer sets up the struct for a Docker container.
+//
+// Names of containers will be unique.
+// Containers will check flags for profiling requests.
+func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ c := MakeNativeContainer(ctx, logger)
+ c.runtime = *runtime
+ if p := MakePprofFromFlags(c); p != nil {
+ c.AddProfile(p)
+ }
+ return c
+}
+
+// MakeNativeContainer sets up the struct for a Docker container using runc.
+// Native containers aren't profiled.
+func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+ return &Container{
+ logger: logger,
+ Name: name,
+ runtime: "",
+ client: client,
+ }
+}
+
+// AddProfile adds a profile to this container.
+func (c *Container) AddProfile(p Profile) {
+ c.profiles = append(c.profiles, p)
+}
+
+// RestartProfiles calls Restart on all profiles for this container.
+func (c *Container) RestartProfiles() error {
+ for _, profile := range c.profiles {
+ if err := profile.Restart(c); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Spawn is analogous to 'docker run -d'.
+func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ return err
+ }
+ return c.Start(ctx)
+}
+
+// SpawnProcess is analogous to 'docker run -it'. It returns a process
+// which represents the root process.
+func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) {
+ config, hostconf, netconf := c.ConfigsFrom(r, args...)
+ config.Tty = true
+ config.OpenStdin = true
+
+ if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil {
+ return Process{}, err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return Process{}, err
+ }
+
+ return Process{container: c, conn: c.streams}, nil
+}
+
+// Run is analogous to 'docker run'.
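
A usage sketch for the RunOpts fields above, in the spirit of the README example (the image name, limits, and port are illustrative):

```
ctx := context.Background()
c := dockerutil.MakeContainer(ctx, t)
defer c.CleanUp(ctx)

out, err := c.Run(ctx, dockerutil.RunOpts{
	Image:      "basic/alpine", // mangled to the first-party test image
	Memory:     512 << 20,      // limit in bytes (512 MiB)
	CpusetCpus: "0-1",
	Ports:      []int{8080},
}, "echo", "hello")
if err != nil {
	t.Fatalf("run failed: %v", err)
}
t.Logf("container output: %s", out)
```
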
+func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { + return "", err + } + + if err := c.Start(ctx); err != nil { + return "", err + } + + if err := c.Wait(ctx); err != nil { + return "", err + } + + return c.Logs(ctx) +} + +// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom' +// and Start. +func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) { + return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{} +} + +// MakeLink formats a link to add to a RunOpts. +func (c *Container) MakeLink(target string) string { + return fmt.Sprintf("%s:%s", c.Name, target) +} + +// CreateFrom creates a container from the given configs. +func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + return c.create(ctx, conf, hostconf, netconf) +} + +// Create is analogous to 'docker create'. +func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { + return c.create(ctx, c.config(r, args), c.hostConfig(r), nil) +} + +func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) + if err != nil { + return err + } + c.id = cont.ID + for _, profile := range c.profiles { + if err := profile.OnCreate(c); err != nil { + return fmt.Errorf("OnCreate method failed with: %v", err) + } + } + return nil +} + +func (c *Container) config(r RunOpts, args []string) *container.Config { + ports := nat.PortSet{} + for _, p := range r.Ports { + port := nat.Port(fmt.Sprintf("%d", p)) + ports[port] = struct{}{} + } + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + + return &container.Config{ + Image: testutil.ImageByName(r.Image), + Cmd: args, + ExposedPorts: ports, + Env: env, + WorkingDir: r.WorkDir, + User: r.User, + } +} + +func (c *Container) hostConfig(r RunOpts) *container.HostConfig { + c.mounts = append(c.mounts, r.Mounts...) + + return &container.HostConfig{ + Runtime: c.runtime, + Mounts: c.mounts, + PublishAllPorts: true, + Links: r.Links, + CapAdd: r.CapAdd, + CapDrop: r.CapDrop, + Privileged: r.Privileged, + ReadonlyRootfs: r.ReadOnly, + Resources: container.Resources{ + Memory: int64(r.Memory), // In bytes. + CpusetCpus: r.CpusetCpus, + }, + } +} + +// Start is analogous to 'docker start'. +func (c *Container) Start(ctx context.Context) error { + + // Open a connection to the container for parsing logs and for TTY. + streams, err := c.client.ContainerAttach(ctx, c.id, + types.ContainerAttachOptions{ + Stream: true, + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + return fmt.Errorf("failed to connect to container: %v", err) + } + + c.streams = streams + c.cleanups = append(c.cleanups, func() { + c.streams.Close() + }) + if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { + return fmt.Errorf("ContainerStart failed: %v", err) + } + for _, profile := range c.profiles { + if err := profile.OnStart(c); err != nil { + return fmt.Errorf("OnStart method failed: %v", err) + } + } + return nil +} + +// Stop is analogous to 'docker stop'. 
+func (c *Container) Stop(ctx context.Context) error {
+ return c.client.ContainerStop(ctx, c.id, nil)
+}
+
+// Pause is analogous to 'docker pause'.
+func (c *Container) Pause(ctx context.Context) error {
+ return c.client.ContainerPause(ctx, c.id)
+}
+
+// Unpause is analogous to 'docker unpause'.
+func (c *Container) Unpause(ctx context.Context) error {
+ return c.client.ContainerUnpause(ctx, c.id)
+}
+
+// Checkpoint is analogous to 'docker checkpoint'.
+func (c *Container) Checkpoint(ctx context.Context, name string) error {
+ return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true})
+}
+
+// Restore is analogous to 'docker start --checkpoint [name]'.
+func (c *Container) Restore(ctx context.Context, name string) error {
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name})
+}
+
+// Logs is analogous to 'docker logs'.
+func (c *Container) Logs(ctx context.Context) (string, error) {
+ var out bytes.Buffer
+ err := c.logs(ctx, &out, &out)
+ return out.String(), err
+}
+
+func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error {
+ opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true}
+ writer, err := c.client.ContainerLogs(ctx, c.id, opts)
+ if err != nil {
+ return err
+ }
+ defer writer.Close()
+ _, err = stdcopy.StdCopy(stdout, stderr, writer)
+ return err
+}
+
+// ID returns the container id.
+func (c *Container) ID() string {
+ return c.id
+}
+
+// SandboxPid returns the container's pid.
+func (c *Container) SandboxPid(ctx context.Context) (int, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, err
+ }
+ return resp.ContainerJSONBase.State.Pid, nil
+}
+
+// FindIP returns the IP address of the container.
+func (c *Container) FindIP(ctx context.Context) (net.IP, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return nil, err
+ }
+
+ ip := net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress)
+ if ip == nil {
+ return net.IP{}, fmt.Errorf("invalid IP: %q", ip)
+ }
+ return ip, nil
+}
+
+// FindPort returns the host port that is mapped to 'sandboxPort'.
+func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) {
+ desc, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+ }
+
+ format := fmt.Sprintf("%d/tcp", sandboxPort)
+ ports, ok := desc.NetworkSettings.Ports[nat.Port(format)]
+ if !ok {
+ return -1, fmt.Errorf("port %d is not mapped", sandboxPort)
+ }
+
+ port, err := strconv.Atoi(ports[0].HostPort)
+ if err != nil {
+ return -1, fmt.Errorf("error parsing port %q: %v", ports[0].HostPort, err)
+ }
+ return port, nil
+}
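
Checkpoint and Restore above pair up into a stop-and-resume cycle; Checkpoint is created with Exit: true, so the container stops once the checkpoint is taken. A hedged usage sketch (checkpoint support must also be enabled in the Docker daemon, which is an external assumption here; the checkpoint name is illustrative):

```
// Sketch: checkpoint a running container, then resume it from the image.
if err := c.Checkpoint(ctx, "cp0"); err != nil {
	t.Fatalf("checkpoint failed: %v", err)
}
// The container is now stopped; later:
if err := c.Restore(ctx, "cp0"); err != nil {
	t.Fatalf("restore failed: %v", err)
}
```
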
+// CopyFiles copies the given files into a temporary directory and bind
+// mounts that directory at target inside the container. Note that the mount
+// itself is created read-write.
+func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) {
+	dir, err := ioutil.TempDir("", c.Name)
+	if err != nil {
+		c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
+		return
+	}
+	c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) })
+	if err := os.Chmod(dir, 0755); err != nil {
+		c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
+		return
+	}
+	for _, name := range sources {
+		src, err := testutil.FindFile(name)
+		if err != nil {
+			c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
+			return
+		}
+		dst := path.Join(dir, path.Base(name))
+		if err := testutil.Copy(src, dst); err != nil {
+			c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
+			return
+		}
+		c.logger.Logf("copy: %s -> %s", src, dst)
+	}
+	opts.Mounts = append(opts.Mounts, mount.Mount{
+		Type:     mount.TypeBind,
+		Source:   dir,
+		Target:   target,
+		ReadOnly: false,
+	})
+}
+
+// Status inspects the container and returns its status.
+func (c *Container) Status(ctx context.Context) (types.ContainerState, error) {
+	resp, err := c.client.ContainerInspect(ctx, c.id)
+	if err != nil {
+		return types.ContainerState{}, err
+	}
+	return *resp.State, nil
+}
+
+// Wait waits for the container to exit.
+func (c *Container) Wait(ctx context.Context) error {
+	statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+	select {
+	case err := <-errChan:
+		return err
+	case <-statusChan:
+		return nil
+	}
+}
+
+// WaitTimeout waits for the container to exit with a timeout.
+func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error {
+	timeoutChan := time.After(timeout)
+	statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+	select {
+	case err := <-errChan:
+		return err
+	case <-statusChan:
+		return nil
+	case <-timeoutChan:
+		return fmt.Errorf("container %s timed out after %v", c.Name, timeout)
+	}
+}
+
+// WaitForOutput searches the container's output for the given pattern and
+// returns the first match, or times out.
+func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) {
+	matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout)
+	if err != nil {
+		return "", err
+	}
+	if len(matches) == 0 {
+		return "", fmt.Errorf("didn't find pattern %s in logs", pattern)
+	}
+	return matches[0], nil
+}
+
+// WaitForOutputSubmatch searches container logs for the given
+// pattern or times out. It returns any regexp submatches as well.
+func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) {
+	re := regexp.MustCompile(pattern)
+	if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+		return matches, nil
+	}
+
+	for exp := time.Now().Add(timeout); time.Now().Before(exp); {
+		c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
+		_, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader)
+		if err != nil {
+			// A read timeout just means there is no new output yet; any
+			// other error is fatal.
+			if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() {
+				return nil, err
+			}
+		}
+
+		if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+			return matches, nil
+		}
+	}
+
+	return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String())
+}
+
+// Kill kills the container.
+func (c *Container) Kill(ctx context.Context) error {
+	return c.client.ContainerKill(ctx, c.id, "")
+}
+
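As a usage note, WaitForOutput pairs naturally with Spawn for servers that signal readiness on their logs. A sketch using the same imports as the earlier test sketch plus "time"; the image and log pattern are illustrative and depend on the workload:

    func TestServerReady(t *testing.T) {
    	ctx := context.Background()
    	c := dockerutil.MakeContainer(ctx, t)
    	defer c.CleanUp(ctx)

    	if err := c.Spawn(ctx, dockerutil.RunOpts{Image: "basic/nginx"}); err != nil {
    		t.Fatalf("Spawn failed: %v", err)
    	}
    	// Block until the server logs readiness, or time out.
    	if _, err := c.WaitForOutput(ctx, "start worker processes", 30*time.Second); err != nil {
    		t.Fatalf("server did not become ready: %v", err)
    	}
    }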
+// Remove is analogous to 'docker rm'.
+func (c *Container) Remove(ctx context.Context) error {
+	// Remove the container.
+	remove := types.ContainerRemoveOptions{
+		RemoveVolumes: c.mounts != nil,
+		RemoveLinks:   c.links != nil,
+		Force:         true,
+	}
+	return c.client.ContainerRemove(ctx, c.Name, remove)
+}
+
+// CleanUp kills and deletes the container (best effort).
+func (c *Container) CleanUp(ctx context.Context) {
+	// Execute profile cleanups before the container goes down.
+	for _, profile := range c.profiles {
+		profile.OnCleanUp(c)
+	}
+	// Forget profiles.
+	c.profiles = nil
+	// Kill the container.
+	if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
+		// Just log; can't do anything here.
+		c.logger.Logf("error killing container %q: %v", c.Name, err)
+	}
+	// Remove the container.
+	if err := c.Remove(ctx); err != nil {
+		c.logger.Logf("error removing container %q: %v", c.Name, err)
+	}
+	// Forget all mounts.
+	c.mounts = nil
+	// Execute all cleanups.
+	for _, c := range c.cleanups {
+		c()
+	}
+	c.cleanups = nil
+}
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index c45d2ecbc..5a9dd8bd8 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -22,17 +22,11 @@ import (
 	"io"
 	"io/ioutil"
 	"log"
-	"net"
-	"os"
 	"os/exec"
-	"path"
 	"regexp"
 	"strconv"
-	"strings"
-	"syscall"
 	"time"
 
-	"github.com/kr/pty"
 	"gvisor.dev/gvisor/pkg/test/testutil"
 )
 
@@ -49,6 +43,26 @@ var (
 	// config is the default Docker daemon configuration path.
 	config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+
+	// The following flags are for the "pprof" profiler tool.
+
+	// pprofBaseDir allows the user to change the directory to which profiles are
+	// written. By default, profiles will appear under:
+	// /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof.
+	pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTAINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
+
+	// duration is the max duration `runsc debug` will run and capture profiles.
+	// If the container's clean up method is called prior to duration, the
+	// profiling process will be killed.
+	duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profiler")
+
+	// The below flags enable each type of profile. Multiple profiles can be
+	// enabled for each run.
+	pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug")
+	pprofCPU   = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug")
+	pprofGo    = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug")
+	pprofHeap  = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug")
+	pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug")
 )
 
 // EnsureSupportedDockerVersion checks if correct docker is installed.
@@ -127,570 +141,7 @@ func Save(logger testutil.Logger, image string, w io.Writer) error {
 	return cmd.Run()
 }
 
-// MountMode describes if the mount should be ro or rw.
-type MountMode int
-
-const (
-	// ReadOnly is what the name says.
-	ReadOnly MountMode = iota
-	// ReadWrite is what the name says.
-	ReadWrite
-)
-
-// String returns the mount mode argument for this MountMode.
-func (m MountMode) String() string {
-	switch m {
-	case ReadOnly:
-		return "ro"
-	case ReadWrite:
-		return "rw"
-	}
-	panic(fmt.Sprintf("invalid mode: %d", m))
-}
-
-// DockerNetwork contains the name of a docker network.
-type DockerNetwork struct { - logger testutil.Logger - Name string - Subnet *net.IPNet - containers []*Docker -} - -// NewDockerNetwork sets up the struct for a Docker network. Names of networks -// will be unique. -func NewDockerNetwork(logger testutil.Logger) *DockerNetwork { - return &DockerNetwork{ - logger: logger, - Name: testutil.RandomID(logger.Name()), - } -} - -// Create calls 'docker network create'. -func (n *DockerNetwork) Create(args ...string) error { - a := []string{"docker", "network", "create"} - if n.Subnet != nil { - a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet)) - } - a = append(a, args...) - a = append(a, n.Name) - return testutil.Command(n.logger, a...).Run() -} - -// Connect calls 'docker network connect' with the arguments provided. -func (n *DockerNetwork) Connect(container *Docker, args ...string) error { - a := []string{"docker", "network", "connect"} - a = append(a, args...) - a = append(a, n.Name, container.Name) - if err := testutil.Command(n.logger, a...).Run(); err != nil { - return err - } - n.containers = append(n.containers, container) - return nil -} - -// Cleanup cleans up the docker network and all the containers attached to it. -func (n *DockerNetwork) Cleanup() error { - for _, c := range n.containers { - // Don't propagate the error, it might be that the container - // was already cleaned up. - if err := c.Kill(); err != nil { - n.logger.Logf("unable to kill container during cleanup: %s", err) - } - } - - if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil { - return err - } - return nil -} - -// Docker contains the name and the runtime of a docker container. -type Docker struct { - logger testutil.Logger - Runtime string - Name string - copyErr error - mounts []string - cleanups []func() -} - -// MakeDocker sets up the struct for a Docker container. -// -// Names of containers will be unique. -func MakeDocker(logger testutil.Logger) *Docker { - // Slashes are not allowed in container names. - name := testutil.RandomID(logger.Name()) - name = strings.ReplaceAll(name, "/", "-") - - return &Docker{ - logger: logger, - Name: name, - Runtime: *runtime, - } -} - -// Mount mounts the given source and makes it available in the container. -func (d *Docker) Mount(target, source string, mode MountMode) { - d.mounts = append(d.mounts, fmt.Sprintf("-v=%s:%s:%v", source, target, mode)) -} - -// CopyFiles copies in and mounts the given files. They are always ReadOnly. -func (d *Docker) CopyFiles(target string, sources ...string) { - dir, err := ioutil.TempDir("", d.Name) - if err != nil { - d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) - return - } - d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) }) - if err := os.Chmod(dir, 0755); err != nil { - d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) - return - } - for _, name := range sources { - src, err := testutil.FindFile(name) - if err != nil { - d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) - return - } - dst := path.Join(dir, path.Base(name)) - if err := testutil.Copy(src, dst); err != nil { - d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) - return - } - d.logger.Logf("copy: %s -> %s", src, dst) - } - d.Mount(target, dir, ReadOnly) -} - -// Link links the given target. -func (d *Docker) Link(target string, source *Docker) { - d.mounts = append(d.mounts, fmt.Sprintf("--link=%s:%s", source.Name, target)) -} - -// RunOpts are options for running a container. 
-type RunOpts struct { - // Image is the image relative to images/. This will be mangled - // appropriately, to ensure that only first-party images are used. - Image string - - // Memory is the memory limit in kB. - Memory int - - // Ports are the ports to be allocated. - Ports []int - - // WorkDir sets the working directory. - WorkDir string - - // ReadOnly sets the read-only flag. - ReadOnly bool - - // Env are additional environment variables. - Env []string - - // User is the user to use. - User string - - // Privileged enables privileged mode. - Privileged bool - - // CapAdd are the extra set of capabilities to add. - CapAdd []string - - // CapDrop are the extra set of capabilities to drop. - CapDrop []string - - // Pty indicates that a pty will be allocated. If this is non-nil, then - // this will run after start-up with the *exec.Command and Pty file - // passed in to the function. - Pty func(*exec.Cmd, *os.File) - - // Foreground indicates that the container should be run in the - // foreground. If this is true, then the output will be available as a - // return value from the Run function. - Foreground bool - - // Extra are extra arguments that may be passed. - Extra []string -} - -// args returns common arguments. -// -// Note that this does not define the complete behavior. -func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) { - isExec := command == "exec" - isRun := command == "run" - - if isRun || isExec { - rv = append(rv, "-i") - } - if r.Pty != nil { - rv = append(rv, "-t") - } - if r.User != "" { - rv = append(rv, fmt.Sprintf("--user=%s", r.User)) - } - if r.Privileged { - rv = append(rv, "--privileged") - } - for _, c := range r.CapAdd { - rv = append(rv, fmt.Sprintf("--cap-add=%s", c)) - } - for _, c := range r.CapDrop { - rv = append(rv, fmt.Sprintf("--cap-drop=%s", c)) - } - for _, e := range r.Env { - rv = append(rv, fmt.Sprintf("--env=%s", e)) - } - if r.WorkDir != "" { - rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir)) - } - if !isExec { - if r.Memory != 0 { - rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory)) - } - for _, p := range r.Ports { - rv = append(rv, fmt.Sprintf("--publish=%d", p)) - } - if r.ReadOnly { - rv = append(rv, fmt.Sprintf("--read-only")) - } - if len(p) > 0 { - rv = append(rv, "--entrypoint=") - } - } - - // Always attach the test environment & Extra. - rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name)) - rv = append(rv, r.Extra...) - - // Attach necessary bits. - if isExec { - rv = append(rv, d.Name) - } else { - rv = append(rv, d.mounts...) - if len(d.Runtime) > 0 { - rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) - } - rv = append(rv, fmt.Sprintf("--name=%s", d.Name)) - rv = append(rv, testutil.ImageByName(r.Image)) - } - - // Attach other arguments. - rv = append(rv, p...) - return rv -} - -// run runs a complete command. -func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) { - if d.copyErr != nil { - return "", d.copyErr - } - basicArgs := []string{"docker"} - if command == "spawn" { - command = "run" - basicArgs = append(basicArgs, command) - basicArgs = append(basicArgs, "-d") - } else { - basicArgs = append(basicArgs, command) - } - customArgs := d.argsFor(&r, command, p) - cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...) - if r.Pty != nil { - // If allocating a terminal, then we just ignore the output - // from the command. 
- ptmx, err := pty.Start(cmd.Cmd) - if err != nil { - return "", err - } - defer cmd.Wait() // Best effort. - r.Pty(cmd.Cmd, ptmx) - } else { - // Can't support PTY or streaming. - out, err := cmd.CombinedOutput() - return string(out), err - } - return "", nil -} - -// Create calls 'docker create' with the arguments provided. -func (d *Docker) Create(r RunOpts, args ...string) error { - out, err := d.run(r, "create", args...) - if strings.Contains(out, "Unable to find image") { - return fmt.Errorf("unable to find image, did you remember to `make load-%s`: %w", r.Image, err) - } - return err -} - -// Start calls 'docker start'. -func (d *Docker) Start() error { - return testutil.Command(d.logger, "docker", "start", d.Name).Run() -} - -// Stop calls 'docker stop'. -func (d *Docker) Stop() error { - return testutil.Command(d.logger, "docker", "stop", d.Name).Run() -} - -// Run calls 'docker run' with the arguments provided. -func (d *Docker) Run(r RunOpts, args ...string) (string, error) { - return d.run(r, "run", args...) -} - -// Spawn starts the container and detaches. -func (d *Docker) Spawn(r RunOpts, args ...string) error { - _, err := d.run(r, "spawn", args...) - return err -} - -// Logs calls 'docker logs'. -func (d *Docker) Logs() (string, error) { - // Don't capture the output; since it will swamp the logs. - out, err := exec.Command("docker", "logs", d.Name).CombinedOutput() - return string(out), err -} - -// Exec calls 'docker exec' with the arguments provided. -func (d *Docker) Exec(r RunOpts, args ...string) (string, error) { - return d.run(r, "exec", args...) -} - -// Pause calls 'docker pause'. -func (d *Docker) Pause() error { - return testutil.Command(d.logger, "docker", "pause", d.Name).Run() -} - -// Unpause calls 'docker pause'. -func (d *Docker) Unpause() error { - return testutil.Command(d.logger, "docker", "unpause", d.Name).Run() -} - -// Checkpoint calls 'docker checkpoint'. -func (d *Docker) Checkpoint(name string) error { - return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run() -} - -// Restore calls 'docker start --checkname [name]'. -func (d *Docker) Restore(name string) error { - return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run() -} - -// Kill calls 'docker kill'. -func (d *Docker) Kill() error { - // Skip logging this command, it will likely be an error. - out, err := exec.Command("docker", "kill", d.Name).CombinedOutput() - if err != nil && !strings.Contains(string(out), "is not running") { - return err - } - return nil -} - -// Remove calls 'docker rm'. -func (d *Docker) Remove() error { - return testutil.Command(d.logger, "docker", "rm", d.Name).Run() -} - -// CleanUp kills and deletes the container (best effort). -func (d *Docker) CleanUp() { - // Kill the container. - if err := d.Kill(); err != nil { - // Just log; can't do anything here. - d.logger.Logf("error killing container %q: %v", d.Name, err) - } - // Remove the image. - if err := d.Remove(); err != nil { - d.logger.Logf("error removing container %q: %v", d.Name, err) - } - // Forget all mounts. - d.mounts = nil - // Execute all cleanups. - for _, c := range d.cleanups { - c() - } - d.cleanups = nil -} - -// FindPort returns the host port that is mapped to 'sandboxPort'. This calls -// docker to allocate a free port in the host and prevent conflicts. 
-func (d *Docker) FindPort(sandboxPort int) (int, error) { - format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort) - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving port: %v", err) - } - port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing port %q: %v", out, err) - } - return port, nil -} - -// FindIP returns the IP address of the container. -func (d *Docker) FindIP() (net.IP, error) { - const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return net.IP{}, fmt.Errorf("error retrieving IP: %v", err) - } - ip := net.ParseIP(strings.TrimSpace(string(out))) - if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", string(out)) - } - return ip, nil -} - -// A NetworkInterface is container's network interface information. -type NetworkInterface struct { - IPv4 net.IP - MAC net.HardwareAddr -} - -// ListNetworks returns the network interfaces of the container, keyed by -// Docker network name. -func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) { - const format = `{{json .NetworkSettings.Networks}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err) - } - - networks := map[string]map[string]string{} - if err := json.Unmarshal(out, &networks); err != nil { - return nil, fmt.Errorf("error decoding network interfaces: %w", err) - } - - interfaces := map[string]NetworkInterface{} - for name, iface := range networks { - var netface NetworkInterface - - rawIP := strings.TrimSpace(iface["IPAddress"]) - if rawIP != "" { - ip := net.ParseIP(rawIP) - if ip == nil { - return nil, fmt.Errorf("invalid IP: %q", rawIP) - } - // Docker's IPAddress field is IPv4. The IPv6 address - // is stored in the GlobalIPv6Address field. - netface.IPv4 = ip - } - - rawMAC := strings.TrimSpace(iface["MacAddress"]) - if rawMAC != "" { - mac, err := net.ParseMAC(rawMAC) - if err != nil { - return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err) - } - netface.MAC = mac - } - - interfaces[name] = netface - } - - return interfaces, nil -} - -// SandboxPid returns the PID to the sandbox process. -func (d *Docker) SandboxPid() (int, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving pid: %v", err) - } - pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing pid %q: %v", out, err) - } - return pid, nil -} - -// ID returns the container ID. -func (d *Docker) ID() (string, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput() - if err != nil { - return "", fmt.Errorf("error retrieving ID: %v", err) - } - return strings.TrimSpace(string(out)), nil -} - -// Wait waits for container to exit, up to the given timeout. Returns error if -// wait fails or timeout is hit. Returns the application return code otherwise. -// Note that the application may have failed even if err == nil, always check -// the exit code. 
-func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) { - timeoutChan := time.After(timeout) - waitChan := make(chan (syscall.WaitStatus)) - errChan := make(chan (error)) - - go func() { - out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput() - if err != nil { - errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err) - } - exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err) - } - waitChan <- syscall.WaitStatus(uint32(exit)) - }() - - select { - case ws := <-waitChan: - return ws, nil - case err := <-errChan: - return syscall.WaitStatus(1), err - case <-timeoutChan: - return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name) - } -} - -// WaitForOutput calls 'docker logs' to retrieve containers output and searches -// for the given pattern. -func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) { - matches, err := d.WaitForOutputSubmatch(pattern, timeout) - if err != nil { - return "", err - } - if len(matches) == 0 { - return "", nil - } - return matches[0], nil -} - -// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and -// searches for the given pattern. It returns any regexp submatches as well. -func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) { - re := regexp.MustCompile(pattern) - var ( - lastOut string - stopped bool - ) - for exp := time.Now().Add(timeout); time.Now().Before(exp); { - out, err := d.Logs() - if err != nil { - return nil, err - } - if out != lastOut { - if lastOut == "" { - d.logger.Logf("output (start): %s", out) - } else if strings.HasPrefix(out, lastOut) { - d.logger.Logf("output (contn): %s", out[len(lastOut):]) - } else { - d.logger.Logf("output (trunc): %s", out) - } - lastOut = out // Save for future. - if matches := re.FindStringSubmatch(lastOut); matches != nil { - return matches, nil // Success! - } - } else if stopped { - // The sandbox stopped and we looked at the - // logs at least once since determining that. - return nil, fmt.Errorf("no longer running: %v", err) - } else if pid, err := d.SandboxPid(); pid == 0 || err != nil { - // The sandbox may have stopped, but it's - // possible that it has emitted the terminal - // line between the last call to Logs and here. - stopped = true - } - time.Sleep(100 * time.Millisecond) - } - return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut) +// Runtime returns the value of the flag runtime. +func Runtime() string { + return *runtime } diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go new file mode 100644 index 000000000..4c739c9e9 --- /dev/null +++ b/pkg/test/dockerutil/exec.go @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/pkg/stdcopy" +) + +// ExecOpts holds arguments for Exec calls. +type ExecOpts struct { + // Env are additional environment variables. + Env []string + + // Privileged enables privileged mode. + Privileged bool + + // User is the user to use. + User string + + // Enables Tty and stdin for the created process. + UseTTY bool + + // WorkDir is the working directory of the process. + WorkDir string +} + +// Exec creates a process inside the container. +func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) { + p, err := c.doExec(ctx, opts, args) + if err != nil { + return "", err + } + + if exitStatus, err := p.WaitExitStatus(ctx); err != nil { + return "", err + } else if exitStatus != 0 { + out, _ := p.Logs() + return out, fmt.Errorf("process terminated with status: %d", exitStatus) + } + + return p.Logs() +} + +// ExecProcess creates a process inside the container and returns a process struct +// for the caller to use. +func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) { + return c.doExec(ctx, opts, args) +} + +func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) { + config := c.execConfig(r, args) + resp, err := c.client.ContainerExecCreate(ctx, c.id, config) + if err != nil { + return Process{}, fmt.Errorf("exec create failed with err: %v", err) + } + + hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{}) + if err != nil { + return Process{}, fmt.Errorf("exec attach failed with err: %v", err) + } + + if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil { + hijack.Close() + return Process{}, fmt.Errorf("exec start failed with err: %v", err) + } + + return Process{ + container: c, + execid: resp.ID, + conn: hijack, + }, nil +} + +func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig { + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + return types.ExecConfig{ + AttachStdin: r.UseTTY, + AttachStderr: true, + AttachStdout: true, + Cmd: cmd, + Privileged: r.Privileged, + WorkingDir: r.WorkDir, + Env: env, + Tty: r.UseTTY, + User: r.User, + } + +} + +// Process represents a containerized process. +type Process struct { + container *Container + execid string + conn types.HijackedResponse +} + +// Write writes buf to the process's stdin. +func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) { + p.conn.Conn.SetDeadline(time.Now().Add(timeout)) + return p.conn.Conn.Write(buf) +} + +// Read returns process's stdout and stderr. +func (p *Process) Read() (string, string, error) { + var stdout, stderr bytes.Buffer + if err := p.read(&stdout, &stderr); err != nil { + return "", "", err + } + return stdout.String(), stderr.String(), nil +} + +// Logs returns combined stdout/stderr from the process. +func (p *Process) Logs() (string, error) { + var out bytes.Buffer + if err := p.read(&out, &out); err != nil { + return "", err + } + return out.String(), nil +} + +func (p *Process) read(stdout, stderr *bytes.Buffer) error { + _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader) + return err +} + +// ExitCode returns the process's exit code. +func (p *Process) ExitCode(ctx context.Context) (int, error) { + _, exitCode, err := p.runningExitCode(ctx) + return exitCode, err +} + +// IsRunning checks if the process is running. 
+func (p *Process) IsRunning(ctx context.Context) (bool, error) {
+	running, _, err := p.runningExitCode(ctx)
+	return running, err
+}
+
+// WaitExitStatus waits until the process completes and returns its exit
+// status.
+func (p *Process) WaitExitStatus(ctx context.Context) (int, error) {
+	waitChan := make(chan int)
+	errChan := make(chan error)
+
+	go func() {
+		for {
+			running, exitcode, err := p.runningExitCode(ctx)
+			if err != nil {
+				errChan <- fmt.Errorf("error waiting for process %s in container %s: %v", p.execid, p.container.Name, err)
+				return
+			}
+			if !running {
+				waitChan <- exitcode
+				return
+			}
+			time.Sleep(time.Millisecond * 500)
+		}
+	}()
+
+	select {
+	case ws := <-waitChan:
+		return ws, nil
+	case err := <-errChan:
+		return -1, err
+	}
+}
+
+// runningExitCode reports whether the process is still running and, once it
+// has exited, its exit code.
+func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) {
+	// If execid is not empty, this is an exec'd process.
+	if p.execid != "" {
+		status, err := p.container.client.ContainerExecInspect(ctx, p.execid)
+		return status.Running, status.ExitCode, err
+	}
+	// Otherwise this is the root process.
+	status, err := p.container.Status(ctx)
+	return status.Running, status.ExitCode, err
+}
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
new file mode 100644
index 000000000..047091e75
--- /dev/null
+++ b/pkg/test/dockerutil/network.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+	"context"
+	"net"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/network"
+	"github.com/docker/docker/client"
+	"gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Network is a docker network.
+type Network struct {
+	client     *client.Client
+	id         string
+	logger     testutil.Logger
+	Name       string
+	containers []*Container
+	Subnet     *net.IPNet
+}
+
+// NewNetwork sets up the struct for a Docker network. Names of networks
+// will be unique.
+func NewNetwork(ctx context.Context, logger testutil.Logger) *Network {
+	client, err := client.NewClientWithOpts(client.FromEnv)
+	if err != nil {
+		logger.Logf("create client failed with: %v", err)
+		return nil
+	}
+	client.NegotiateAPIVersion(ctx)
+
+	return &Network{
+		logger: logger,
+		Name:   testutil.RandomID(logger.Name()),
+		client: client,
+	}
+}
+
+func (n *Network) networkCreate() types.NetworkCreate {
+	var subnet string
+	if n.Subnet != nil {
+		subnet = n.Subnet.String()
+	}
+
+	ipam := network.IPAM{
+		Config: []network.IPAMConfig{{
+			Subnet: subnet,
+		}},
+	}
+
+	return types.NetworkCreate{
+		CheckDuplicate: true,
+		IPAM:           &ipam,
+	}
+}
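To illustrate the exec plumbing above, a hypothetical helper that starts a command in a running container and reaps its status via WaitExitStatus; the command itself is illustrative:

    import (
    	"context"

    	"gvisor.dev/gvisor/pkg/test/dockerutil"
    )

    // runAndReap execs a shell command inside a running container and waits
    // for it to finish. For the command below, the returned status is 3.
    func runAndReap(ctx context.Context, c *dockerutil.Container) (int, error) {
    	p, err := c.ExecProcess(ctx, dockerutil.ExecOpts{}, "sh", "-c", "exit 3")
    	if err != nil {
    		return -1, err
    	}
    	return p.WaitExitStatus(ctx)
    }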
+// Create is analogous to 'docker network create'.
+func (n *Network) Create(ctx context.Context) error {
+	opts := n.networkCreate()
+	resp, err := n.client.NetworkCreate(ctx, n.Name, opts)
+	if err != nil {
+		return err
+	}
+	n.id = resp.ID
+	return nil
+}
+
+// Connect is analogous to 'docker network connect' with the arguments provided.
+func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error {
+	settings := network.EndpointSettings{
+		IPAMConfig: &network.EndpointIPAMConfig{
+			IPv4Address: ipv4,
+			IPv6Address: ipv6,
+		},
+	}
+	err := n.client.NetworkConnect(ctx, n.id, container.id, &settings)
+	if err == nil {
+		n.containers = append(n.containers, container)
+	}
+	return err
+}
+
+// Inspect returns this network's info.
+func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
+	return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *Network) Cleanup(ctx context.Context) error {
+	for _, c := range n.containers {
+		c.CleanUp(ctx)
+	}
+	n.containers = nil
+
+	return n.client.NetworkRemove(ctx, n.id)
+}
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
new file mode 100644
index 000000000..1fab33083
--- /dev/null
+++ b/pkg/test/dockerutil/profile.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"time"
+)
+
+// Profile represents profile-like operations on a container,
+// such as running perf or pprof. It is meant to be added to containers
+// such that the container type calls the Profile during its lifecycle.
+type Profile interface {
+	// OnCreate is called just after the container is created when the container
+	// has a valid ID (e.g. c.ID()).
+	OnCreate(c *Container) error
+
+	// OnStart is called just after the container is started when the container
+	// has a valid Pid (e.g. c.SandboxPid()).
+	OnStart(c *Container) error
+
+	// Restart restarts the Profile on request.
+	Restart(c *Container) error
+
+	// OnCleanUp is called during the container's cleanup method.
+	// Cleanups should just log errors if they have them.
+	OnCleanUp(c *Container) error
+}
+
+// Pprof is for running profiles with 'runsc debug'. Pprof workloads
+// should be run as root and ONLY against runsc sandboxes. The runtime
+// should have --profile set as an option in /etc/docker/daemon.json in
+// order for profiling to work with Pprof.
+type Pprof struct {
+	BasePath         string // path to put profiles
+	BlockProfile     bool
+	CPUProfile       bool
+	GoRoutineProfile bool
+	HeapProfile      bool
+	MutexProfile     bool
+	Duration         time.Duration // duration to run profiler e.g. '10s' or '1m'.
+	shouldRun        bool
+	cmd              *exec.Cmd
+	stdout           io.ReadCloser
+	stderr           io.ReadCloser
+}
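Before the flag-driven constructor below, it may help to see the Pprof type wired up by hand. A sketch, assuming a *dockerutil.Container c created via MakeContainer and the AddProfile hook that profile_test.go exercises; the base path and duration are illustrative:

    // Illustrative values; BasePath must be writable by the test.
    p := &dockerutil.Pprof{
    	BasePath:   "/tmp/profile/myrun",
    	CPUProfile: true,
    	Duration:   10 * time.Second,
    }
    // The container invokes OnCreate/OnStart/OnCleanUp at the matching
    // points in its own lifecycle.
    c.AddProfile(p)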
+// MakePprofFromFlags makes a Pprof profile from flags.
+func MakePprofFromFlags(c *Container) *Pprof {
+	if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) {
+		return nil
+	}
+	return &Pprof{
+		BasePath:         filepath.Join(*pprofBaseDir, c.runtime, c.Name),
+		BlockProfile:     *pprofBlock,
+		CPUProfile:       *pprofCPU,
+		GoRoutineProfile: *pprofGo,
+		HeapProfile:      *pprofHeap,
+		MutexProfile:     *pprofMutex,
+		Duration:         *duration,
+	}
+}
+
+// OnCreate implements Profile.OnCreate.
+func (p *Pprof) OnCreate(c *Container) error {
+	return os.MkdirAll(p.BasePath, 0755)
+}
+
+// OnStart implements Profile.OnStart.
+func (p *Pprof) OnStart(c *Container) error {
+	path, err := RuntimePath()
+	if err != nil {
+		return fmt.Errorf("failed to get runtime path: %v", err)
+	}
+
+	// The root directory of this container's runtime.
+	root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
+	// Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`.
+	args := []string{root, "debug"}
+	args = append(args, p.makeProfileArgs(c)...)
+	args = append(args, c.ID())
+
+	// Best effort wait until container is running.
+	for now := time.Now(); time.Since(now) < 5*time.Second; {
+		if status, err := c.Status(context.Background()); err != nil {
+			return fmt.Errorf("failed to get status with: %v", err)
+		} else if status.Running {
+			break
+		}
+		time.Sleep(500 * time.Millisecond)
+	}
+	p.cmd = exec.Command(path, args...)
+	if err := p.cmd.Start(); err != nil {
+		return fmt.Errorf("process failed: %v", err)
+	}
+	return nil
+}
+
+// Restart implements Profile.Restart.
+func (p *Pprof) Restart(c *Container) error {
+	p.OnCleanUp(c)
+	return p.OnStart(c)
+}
+
+// OnCleanUp implements Profile.OnCleanUp.
+func (p *Pprof) OnCleanUp(c *Container) error {
+	defer func() { p.cmd = nil }()
+	// ProcessState is only set once the command has been waited on, so a nil
+	// ProcessState means the profiling process may still be running and must
+	// be killed.
+	if p.cmd != nil && p.cmd.Process != nil && (p.cmd.ProcessState == nil || !p.cmd.ProcessState.Exited()) {
+		return p.cmd.Process.Kill()
+	}
+	return nil
+}
+
+// makeProfileArgs turns Pprof fields into runsc debug flags.
+func (p *Pprof) makeProfileArgs(c *Container) []string {
+	var ret []string
+	if p.BlockProfile {
+		ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof")))
+	}
+	if p.CPUProfile {
+		ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
+	}
+	if p.GoRoutineProfile {
+		ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof")))
+	}
+	if p.HeapProfile {
+		ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
+	}
+	if p.MutexProfile {
+		ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof")))
+	}
+	ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration))
+	return ret
+}
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
new file mode 100644
index 000000000..b7b4d7618
--- /dev/null
+++ b/pkg/test/dockerutil/profile_test.go
@@ -0,0 +1,117 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+type testCase struct {
+	name          string
+	pprof         Pprof
+	expectedFiles []string
+}
+
+func TestPprof(t *testing.T) {
+	// Base path and expected file names for each type of profile.
+	basePath := "/tmp/test/profile"
+	block := "block.pprof"
+	cpu := "cpu.pprof"
+	goProfile := "go.pprof"
+	heap := "heap.pprof"
+	mutex := "mutex.pprof"
+
+	testCases := []testCase{
+		{
+			name: "Cpu",
+			pprof: Pprof{
+				BasePath:   basePath,
+				CPUProfile: true,
+				Duration:   2 * time.Second,
+			},
+			expectedFiles: []string{cpu},
+		},
+		{
+			name: "All",
+			pprof: Pprof{
+				BasePath:         basePath,
+				BlockProfile:     true,
+				CPUProfile:       true,
+				GoRoutineProfile: true,
+				HeapProfile:      true,
+				MutexProfile:     true,
+				Duration:         2 * time.Second,
+			},
+			expectedFiles: []string{block, cpu, goProfile, heap, mutex},
+		},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ctx := context.Background()
+			c := MakeContainer(ctx, t)
+			// Set basepath to include the container name so there are no conflicts.
+			tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name)
+			c.AddProfile(&tc.pprof)
+
+			func() {
+				defer c.CleanUp(ctx)
+				// Start a container.
+				if err := c.Spawn(ctx, RunOpts{
+					Image: "basic/alpine",
+				}, "sleep", "1000"); err != nil {
+					t.Fatalf("run failed with: %v", err)
+				}
+
+				if status, err := c.Status(context.Background()); !status.Running {
+					t.Fatalf("container is not yet running: %+v err: %v", status, err)
+				}
+
+				// End early if the expected files exist and have data.
+				for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) {
+					if err := checkFiles(tc); err == nil {
+						break
+					}
+				}
+			}()
+
+			// Check all expected files exist and have data.
+			if err := checkFiles(tc); err != nil {
+				t.Fatal(err)
+			}
+		})
+	}
+}
+
+func checkFiles(tc testCase) error {
+	for _, file := range tc.expectedFiles {
+		stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file))
+		if err != nil {
+			return fmt.Errorf("stat failed with: %v", err)
+		} else if stat.Size() < 1 {
+			return fmt.Errorf("file not written to: %+v", stat)
+		}
+	}
+	return nil
+}
+
+func TestMain(m *testing.M) {
+	EnsureSupportedDockerVersion()
+	os.Exit(m.Run())
+}
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
index 03b1b4677..2d8f56bc0 100644
--- a/pkg/test/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -15,6 +15,6 @@ go_library(
         "//runsc/boot",
         "//runsc/specutils",
         "@com_github_cenkalti_backoff//:go_default_library",
-        "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+        "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
     ],
 )
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index f21d6769a..64c292698 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -482,6 +482,21 @@ func IsStatic(filename string) (bool, error) {
 	return true, nil
 }
 
+// TouchShardStatusFile indicates to Bazel that the test runner supports
+// sharding by creating or updating the last modified date of the file
+// specified by TEST_SHARD_STATUS_FILE.
+//
+// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
+func TouchShardStatusFile() error { + if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" { + cmd := exec.Command("touch", statusFile) + if b, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error()) + } + } + return nil +} + // TestIndicesForShard returns indices for this test shard based on the // TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. // diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 707eb085b..67a950444 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -128,13 +128,6 @@ type EntryCallback interface { // // +stateify savable type Entry struct { - // Context stores any state the waiter may wish to store in the entry - // itself, which may be used at wake up time. - // - // Note that use of this field is optional and state may alternatively be - // stored in the callback itself. - Context interface{} - Callback EntryCallback // The following fields are protected by the queue lock. @@ -142,13 +135,14 @@ type Entry struct { waiterEntry } -type channelCallback struct{} +type channelCallback struct { + ch chan struct{} +} // Callback implements EntryCallback.Callback. -func (*channelCallback) Callback(e *Entry) { - ch := e.Context.(chan struct{}) +func (c *channelCallback) Callback(*Entry) { select { - case ch <- struct{}{}: + case c.ch <- struct{}{}: default: } } @@ -164,7 +158,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { c = make(chan struct{}, 1) } - return Entry{Context: c, Callback: &channelCallback{}}, c + return Entry{Callback: &channelCallback{ch: c}}, c } // Queue represents the wait queue where waiters can be added and |
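Finally, for the waiter.go hunk above: the channel now lives in channelCallback rather than in the removed Entry.Context field, but the caller-side pattern is unchanged. A sketch of that pattern, assuming the Queue's existing EventRegister/EventUnregister API (not shown in this diff):

    import (
    	"time"

    	"gvisor.dev/gvisor/pkg/waiter"
    )

    // waitReady blocks until wq signals EventIn or the timeout elapses.
    func waitReady(wq *waiter.Queue, timeout time.Duration) bool {
    	// NewChannelEntry allocates the channel and stores it in the callback.
    	e, ch := waiter.NewChannelEntry(nil)
    	wq.EventRegister(&e, waiter.EventIn)
    	defer wq.EventUnregister(&e)

    	select {
    	case <-ch:
    		return true
    	case <-time.After(timeout):
    		return false
    	}
    }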