diff options
Diffstat (limited to 'pkg')
486 files changed, 21953 insertions, 8970 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index b5c5cc20b..cdcaa8c73 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -74,9 +74,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/usermem", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go index 86ee3f8b5..5fc099892 100644 --- a/pkg/abi/linux/aio.go +++ b/pkg/abi/linux/aio.go @@ -42,6 +42,8 @@ const ( // // The priority field is currently ignored in the implementation below. Also // note that the IOCB_FLAG_RESFD feature is not supported. +// +// +marshal type IOCallback struct { Data uint64 Key uint32 @@ -64,6 +66,7 @@ type IOCallback struct { // IOEvent describes an I/O result. // +// +marshal // +stateify savable type IOEvent struct { Data uint64 diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go index aa3d3ce70..9422fcf69 100644 --- a/pkg/abi/linux/bpf.go +++ b/pkg/abi/linux/bpf.go @@ -16,6 +16,7 @@ package linux // BPFInstruction is a raw BPF virtual machine instruction. // +// +marshal slice:BPFInstructionSlice // +stateify savable type BPFInstruction struct { // OpCode is the operation to execute. diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go index 965f74663..afd16cc27 100644 --- a/pkg/abi/linux/capability.go +++ b/pkg/abi/linux/capability.go @@ -177,12 +177,16 @@ const ( ) // CapUserHeader is equivalent to Linux's cap_user_header_t. +// +// +marshal type CapUserHeader struct { Version uint32 Pid int32 } // CapUserData is equivalent to Linux's cap_user_data_t. +// +// +marshal slice:CapUserDataSlice type CapUserData struct { Effective uint32 Permitted uint32 diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go index 192e2093b..7771650b3 100644 --- a/pkg/abi/linux/dev.go +++ b/pkg/abi/linux/dev.go @@ -54,9 +54,9 @@ const ( // Unix98 PTY masters. UNIX98_PTY_MASTER_MAJOR = 128 - // UNIX98_PTY_SLAVE_MAJOR is the initial major device number for - // Unix98 PTY slaves. - UNIX98_PTY_SLAVE_MAJOR = 136 + // UNIX98_PTY_REPLICA_MAJOR is the initial major device number for + // Unix98 PTY replicas. + UNIX98_PTY_REPLICA_MAJOR = 136 ) // Minor device numbers for TTYAUX_MAJOR. diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index 9242e80a5..cc3571fad 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -45,6 +45,8 @@ const ( ) // Flock is the lock structure for F_SETLK. +// +// +marshal type Flock struct { Type int16 Whence int16 @@ -63,6 +65,8 @@ const ( ) // FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +// +// +marshal type FOwnerEx struct { Type int32 PID int32 diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index 7e30483ee..d91c97a64 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -14,12 +14,20 @@ package linux +import ( + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + // +marshal type FUSEOpcode uint32 // +marshal type FUSEOpID uint64 +// FUSE_ROOT_ID is the id of root inode. +const FUSE_ROOT_ID = 1 + // Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. const ( FUSE_LOOKUP FUSEOpcode = 1 @@ -116,61 +124,28 @@ type FUSEHeaderOut struct { Unique FUSEOpID } -// FUSEWriteIn is the header written by a daemon when it makes a -// write request to the FUSE filesystem. -// -// +marshal -type FUSEWriteIn struct { - // Fh specifies the file handle that is being written to. - Fh uint64 - - // Offset is the offset of the write. - Offset uint64 - - // Size is the size of data being written. - Size uint32 - - // WriteFlags is the flags used during the write. - WriteFlags uint32 - - // LockOwner is the ID of the lock owner. - LockOwner uint64 - - // Flags is the flags for the request. - Flags uint32 - - _ uint32 -} - // FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h. +// Our taget version is 7.23 but we have few implemented in advance. const ( - FUSE_ASYNC_READ = 1 << 0 - FUSE_POSIX_LOCKS = 1 << 1 - FUSE_FILE_OPS = 1 << 2 - FUSE_ATOMIC_O_TRUNC = 1 << 3 - FUSE_EXPORT_SUPPORT = 1 << 4 - FUSE_BIG_WRITES = 1 << 5 - FUSE_DONT_MASK = 1 << 6 - FUSE_SPLICE_WRITE = 1 << 7 - FUSE_SPLICE_MOVE = 1 << 8 - FUSE_SPLICE_READ = 1 << 9 - FUSE_FLOCK_LOCKS = 1 << 10 - FUSE_HAS_IOCTL_DIR = 1 << 11 - FUSE_AUTO_INVAL_DATA = 1 << 12 - FUSE_DO_READDIRPLUS = 1 << 13 - FUSE_READDIRPLUS_AUTO = 1 << 14 - FUSE_ASYNC_DIO = 1 << 15 - FUSE_WRITEBACK_CACHE = 1 << 16 - FUSE_NO_OPEN_SUPPORT = 1 << 17 - FUSE_PARALLEL_DIROPS = 1 << 18 - FUSE_HANDLE_KILLPRIV = 1 << 19 - FUSE_POSIX_ACL = 1 << 20 - FUSE_ABORT_ERROR = 1 << 21 - FUSE_MAX_PAGES = 1 << 22 - FUSE_CACHE_SYMLINKS = 1 << 23 - FUSE_NO_OPENDIR_SUPPORT = 1 << 24 - FUSE_EXPLICIT_INVAL_DATA = 1 << 25 - FUSE_MAP_ALIGNMENT = 1 << 26 + FUSE_ASYNC_READ = 1 << 0 + FUSE_POSIX_LOCKS = 1 << 1 + FUSE_FILE_OPS = 1 << 2 + FUSE_ATOMIC_O_TRUNC = 1 << 3 + FUSE_EXPORT_SUPPORT = 1 << 4 + FUSE_BIG_WRITES = 1 << 5 + FUSE_DONT_MASK = 1 << 6 + FUSE_SPLICE_WRITE = 1 << 7 + FUSE_SPLICE_MOVE = 1 << 8 + FUSE_SPLICE_READ = 1 << 9 + FUSE_FLOCK_LOCKS = 1 << 10 + FUSE_HAS_IOCTL_DIR = 1 << 11 + FUSE_AUTO_INVAL_DATA = 1 << 12 + FUSE_DO_READDIRPLUS = 1 << 13 + FUSE_READDIRPLUS_AUTO = 1 << 14 + FUSE_ASYNC_DIO = 1 << 15 + FUSE_WRITEBACK_CACHE = 1 << 16 + FUSE_NO_OPEN_SUPPORT = 1 << 17 + FUSE_MAX_PAGES = 1 << 22 // From FUSE 7.28 ) // currently supported FUSE protocol version numbers. @@ -179,6 +154,13 @@ const ( FUSE_KERNEL_MINOR_VERSION = 31 ) +// Constants relevant to FUSE operations. +const ( + FUSE_NAME_MAX = 1024 + FUSE_PAGE_SIZE = 4096 + FUSE_DIRENT_ALIGN = 8 +) + // FUSEInitIn is the request sent by the kernel to the daemon, // to negotiate the version and flags. // @@ -199,7 +181,7 @@ type FUSEInitIn struct { } // FUSEInitOut is the reply sent by the daemon to the kernel -// for FUSEInitIn. +// for FUSEInitIn. We target FUSE 7.23; this struct supports 7.28. // // +marshal type FUSEInitOut struct { @@ -240,13 +222,16 @@ type FUSEInitOut struct { // if the value from daemon is too large. MaxPages uint16 - // MapAlignment is an unknown field and not used by this package at this moment. - // Use as a placeholder to be consistent with the FUSE protocol. - MapAlignment uint16 + _ uint16 _ [8]uint32 } +// FUSE_GETATTR_FH is currently the only flag of FUSEGetAttrIn.GetAttrFlags. +// If it is set, the file handle (FUSEGetAttrIn.Fh) is used to indicate the +// object instead of the node id attribute in the request header. +const FUSE_GETATTR_FH = (1 << 0) + // FUSEGetAttrIn is the request sent by the kernel to the daemon, // to get the attribute of a inode. // @@ -267,22 +252,52 @@ type FUSEGetAttrIn struct { // // +marshal type FUSEAttr struct { - Ino uint64 - Size uint64 - Blocks uint64 - Atime uint64 - Mtime uint64 - Ctime uint64 + // Ino is the inode number of this file. + Ino uint64 + + // Size is the size of this file. + Size uint64 + + // Blocks is the number of the 512B blocks allocated by this file. + Blocks uint64 + + // Atime is the time of last access. + Atime uint64 + + // Mtime is the time of last modification. + Mtime uint64 + + // Ctime is the time of last status change. + Ctime uint64 + + // AtimeNsec is the nano second part of Atime. AtimeNsec uint32 + + // MtimeNsec is the nano second part of Mtime. MtimeNsec uint32 + + // CtimeNsec is the nano second part of Ctime. CtimeNsec uint32 - Mode uint32 - Nlink uint32 - UID uint32 - GID uint32 - Rdev uint32 - BlkSize uint32 - _ uint32 + + // Mode contains the file type and mode. + Mode uint32 + + // Nlink is the number of the hard links. + Nlink uint32 + + // UID is user ID of the owner. + UID uint32 + + // GID is group ID of the owner. + GID uint32 + + // Rdev is the device ID if this is a special file. + Rdev uint32 + + // BlkSize is the block size for filesystem I/O. + BlkSize uint32 + + _ uint32 } // FUSEGetAttrOut is the reply sent by the daemon to the kernel @@ -301,3 +316,558 @@ type FUSEGetAttrOut struct { // Attr contains the metadata returned from the FUSE server Attr FUSEAttr } + +// FUSEEntryOut is the reply sent by the daemon to the kernel +// for FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and +// FUSE_LOOKUP. +// +// +marshal +type FUSEEntryOut struct { + // NodeID is the ID for current inode. + NodeID uint64 + + // Generation is the generation number of inode. + // Used to identify an inode that have different ID at different time. + Generation uint64 + + // EntryValid indicates timeout for an entry. + EntryValid uint64 + + // AttrValid indicates timeout for an entry's attributes. + AttrValid uint64 + + // EntryValidNsec indicates timeout for an entry in nanosecond. + EntryValidNSec uint32 + + // AttrValidNsec indicates timeout for an entry's attributes in nanosecond. + AttrValidNSec uint32 + + // Attr contains the attributes of an entry. + Attr FUSEAttr +} + +// FUSELookupIn is the request sent by the kernel to the daemon +// to look up a file name. +// +// Dynamically-sized objects cannot be marshalled. +type FUSELookupIn struct { + marshal.StubMarshallable + + // Name is a file name to be looked up. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSELookupIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSELookupIn. +// 1 extra byte for null-terminated string. +func (r *FUSELookupIn) SizeBytes() int { + return len(r.Name) + 1 +} + +// MAX_NON_LFS indicates the maximum offset without large file support. +const MAX_NON_LFS = ((1 << 31) - 1) + +// flags returned by OPEN request. +const ( + // FOPEN_DIRECT_IO indicates bypassing page cache for this opened file. + FOPEN_DIRECT_IO = 1 << 0 + // FOPEN_KEEP_CACHE avoids invalidate of data cache on open. + FOPEN_KEEP_CACHE = 1 << 1 + // FOPEN_NONSEEKABLE indicates the file cannot be seeked. + FOPEN_NONSEEKABLE = 1 << 2 +) + +// FUSEOpenIn is the request sent by the kernel to the daemon, +// to negotiate flags and get file handle. +// +// +marshal +type FUSEOpenIn struct { + // Flags of this open request. + Flags uint32 + + _ uint32 +} + +// FUSEOpenOut is the reply sent by the daemon to the kernel +// for FUSEOpenIn. +// +// +marshal +type FUSEOpenOut struct { + // Fh is the file handler for opened file. + Fh uint64 + + // OpenFlag for the opened file. + OpenFlag uint32 + + _ uint32 +} + +// FUSE_READ flags, consistent with the ones in include/uapi/linux/fuse.h. +const ( + FUSE_READ_LOCKOWNER = 1 << 1 +) + +// FUSEReadIn is the request sent by the kernel to the daemon +// for FUSE_READ. +// +// +marshal +type FUSEReadIn struct { + // Fh is the file handle in userspace. + Fh uint64 + + // Offset is the read offset. + Offset uint64 + + // Size is the number of bytes to read. + Size uint32 + + // ReadFlags for this FUSE_READ request. + // Currently only contains FUSE_READ_LOCKOWNER. + ReadFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 + + // Flags for the underlying file. + Flags uint32 + + _ uint32 +} + +// FUSEWriteIn is the first part of the payload of the +// request sent by the kernel to the daemon +// for FUSE_WRITE (struct for FUSE version >= 7.9). +// +// The second part of the payload is the +// binary bytes of the data to be written. +// +// +marshal +type FUSEWriteIn struct { + // Fh is the file handle in userspace. + Fh uint64 + + // Offset is the write offset. + Offset uint64 + + // Size is the number of bytes to write. + Size uint32 + + // ReadFlags for this FUSE_WRITE request. + WriteFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 + + // Flags for the underlying file. + Flags uint32 + + _ uint32 +} + +// FUSEWriteOut is the payload of the reply sent by the daemon to the kernel +// for a FUSE_WRITE request. +// +// +marshal +type FUSEWriteOut struct { + // Size is the number of bytes written. + Size uint32 + + _ uint32 +} + +// FUSEReleaseIn is the request sent by the kernel to the daemon +// when there is no more reference to a file. +// +// +marshal +type FUSEReleaseIn struct { + // Fh is the file handler for the file to be released. + Fh uint64 + + // Flags of the file. + Flags uint32 + + // ReleaseFlags of this release request. + ReleaseFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 +} + +// FUSECreateMeta contains all the static fields of FUSECreateIn, +// which is used for FUSE_CREATE. +// +// +marshal +type FUSECreateMeta struct { + // Flags of the creating file. + Flags uint32 + + // Mode is the mode of the creating file. + Mode uint32 + + // Umask is the current file mode creation mask. + Umask uint32 + _ uint32 +} + +// FUSECreateIn contains all the arguments sent by the kernel to the daemon, to +// atomically create and open a new regular file. +// +// Dynamically-sized objects cannot be marshalled. +type FUSECreateIn struct { + marshal.StubMarshallable + + // CreateMeta contains mode, rdev and umash field for FUSE_MKNODS. + CreateMeta FUSECreateMeta + + // Name is the name of the node to create. + Name string +} + +// MarshalBytes serializes r.CreateMeta and r.Name to the dst buffer. +func (r *FUSECreateIn) MarshalBytes(buf []byte) { + r.CreateMeta.MarshalBytes(buf[:r.CreateMeta.SizeBytes()]) + copy(buf[r.CreateMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSECreateIn. +// 1 extra byte for null-terminated string. +func (r *FUSECreateIn) SizeBytes() int { + return r.CreateMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSEMknodMeta contains all the static fields of FUSEMknodIn, +// which is used for FUSE_MKNOD. +// +// +marshal +type FUSEMknodMeta struct { + // Mode of the inode to create. + Mode uint32 + + // Rdev encodes device major and minor information. + Rdev uint32 + + // Umask is the current file mode creation mask. + Umask uint32 + + _ uint32 +} + +// FUSEMknodIn contains all the arguments sent by the kernel +// to the daemon, to create a new file node. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEMknodIn struct { + marshal.StubMarshallable + + // MknodMeta contains mode, rdev and umash field for FUSE_MKNODS. + MknodMeta FUSEMknodMeta + + // Name is the name of the node to create. + Name string +} + +// MarshalBytes serializes r.MknodMeta and r.Name to the dst buffer. +func (r *FUSEMknodIn) MarshalBytes(buf []byte) { + r.MknodMeta.MarshalBytes(buf[:r.MknodMeta.SizeBytes()]) + copy(buf[r.MknodMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEMknodIn. +// 1 extra byte for null-terminated string. +func (r *FUSEMknodIn) SizeBytes() int { + return r.MknodMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSESymLinkIn is the request sent by the kernel to the daemon, +// to create a symbolic link. +// +// Dynamically-sized objects cannot be marshalled. +type FUSESymLinkIn struct { + marshal.StubMarshallable + + // Name of symlink to create. + Name string + + // Target of the symlink. + Target string +} + +// MarshalBytes serializes r.Name and r.Target to the dst buffer. +// Left null-termination at end of r.Name and r.Target. +func (r *FUSESymLinkIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) + copy(buf[len(r.Name)+1:], r.Target) +} + +// SizeBytes is the size of the memory representation of FUSESymLinkIn. +// 2 extra bytes for null-terminated string. +func (r *FUSESymLinkIn) SizeBytes() int { + return len(r.Name) + len(r.Target) + 2 +} + +// FUSEEmptyIn is used by operations without request body. +type FUSEEmptyIn struct{ marshal.StubMarshallable } + +// MarshalBytes do nothing for marshal. +func (r *FUSEEmptyIn) MarshalBytes(buf []byte) {} + +// SizeBytes is 0 for empty request. +func (r *FUSEEmptyIn) SizeBytes() int { + return 0 +} + +// FUSEMkdirMeta contains all the static fields of FUSEMkdirIn, +// which is used for FUSE_MKDIR. +// +// +marshal +type FUSEMkdirMeta struct { + // Mode of the directory of create. + Mode uint32 + + // Umask is the user file creation mask. + Umask uint32 +} + +// FUSEMkdirIn contains all the arguments sent by the kernel +// to the daemon, to create a new directory. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEMkdirIn struct { + marshal.StubMarshallable + + // MkdirMeta contains Mode and Umask of the directory to create. + MkdirMeta FUSEMkdirMeta + + // Name of the directory to create. + Name string +} + +// MarshalBytes serializes r.MkdirMeta and r.Name to the dst buffer. +func (r *FUSEMkdirIn) MarshalBytes(buf []byte) { + r.MkdirMeta.MarshalBytes(buf[:r.MkdirMeta.SizeBytes()]) + copy(buf[r.MkdirMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEMkdirIn. +// 1 extra byte for null-terminated Name string. +func (r *FUSEMkdirIn) SizeBytes() int { + return r.MkdirMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSERmDirIn is the request sent by the kernel to the daemon +// when trying to remove a directory. +// +// Dynamically-sized objects cannot be marshalled. +type FUSERmDirIn struct { + marshal.StubMarshallable + + // Name is a directory name to be removed. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSERmDirIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSERmDirIn. +func (r *FUSERmDirIn) SizeBytes() int { + return len(r.Name) + 1 +} + +// FUSEDirents is a list of Dirents received from the FUSE daemon server. +// It is used for FUSE_READDIR. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEDirents struct { + marshal.StubMarshallable + + Dirents []*FUSEDirent +} + +// FUSEDirent is a Dirent received from the FUSE daemon server. +// It is used for FUSE_READDIR. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEDirent struct { + marshal.StubMarshallable + + // Meta contains all the static fields of FUSEDirent. + Meta FUSEDirentMeta + + // Name is the filename of the dirent. + Name string +} + +// FUSEDirentMeta contains all the static fields of FUSEDirent. +// It is used for FUSE_READDIR. +// +// +marshal +type FUSEDirentMeta struct { + // Inode of the dirent. + Ino uint64 + + // Offset of the dirent. + Off uint64 + + // NameLen is the length of the dirent name. + NameLen uint32 + + // Type of the dirent. + Type uint32 +} + +// SizeBytes is the size of the memory representation of FUSEDirents. +func (r *FUSEDirents) SizeBytes() int { + var sizeBytes int + for _, dirent := range r.Dirents { + sizeBytes += dirent.SizeBytes() + } + + return sizeBytes +} + +// UnmarshalBytes deserializes FUSEDirents from the src buffer. +func (r *FUSEDirents) UnmarshalBytes(src []byte) { + for { + if len(src) <= (*FUSEDirentMeta)(nil).SizeBytes() { + break + } + + // Its unclear how many dirents there are in src. Each dirent is dynamically + // sized and so we can't make assumptions about how many dirents we can allocate. + if r.Dirents == nil { + r.Dirents = make([]*FUSEDirent, 0) + } + + // We have to allocate a struct for each dirent - there must be a better way + // to do this. Linux allocates 1 page to store all the dirents and then + // simply reads them from the page. + var dirent FUSEDirent + dirent.UnmarshalBytes(src) + r.Dirents = append(r.Dirents, &dirent) + + src = src[dirent.SizeBytes():] + } +} + +// SizeBytes is the size of the memory representation of FUSEDirent. +func (r *FUSEDirent) SizeBytes() int { + dataSize := r.Meta.SizeBytes() + len(r.Name) + + // Each Dirent must be padded such that its size is a multiple + // of FUSE_DIRENT_ALIGN. Similar to the fuse dirent alignment + // in linux/fuse.h. + return (dataSize + (FUSE_DIRENT_ALIGN - 1)) & ^(FUSE_DIRENT_ALIGN - 1) +} + +// UnmarshalBytes deserializes FUSEDirent from the src buffer. +func (r *FUSEDirent) UnmarshalBytes(src []byte) { + r.Meta.UnmarshalBytes(src) + src = src[r.Meta.SizeBytes():] + + if r.Meta.NameLen > FUSE_NAME_MAX { + // The name is too long and therefore invalid. We don't + // need to unmarshal the name since it'll be thrown away. + return + } + + buf := make([]byte, r.Meta.NameLen) + name := primitive.ByteSlice(buf) + name.UnmarshalBytes(src[:r.Meta.NameLen]) + r.Name = string(name) +} + +// FATTR_* consts are the attribute flags defined in include/uapi/linux/fuse.h. +// These should be or-ed together for setattr to know what has been changed. +const ( + FATTR_MODE = (1 << 0) + FATTR_UID = (1 << 1) + FATTR_GID = (1 << 2) + FATTR_SIZE = (1 << 3) + FATTR_ATIME = (1 << 4) + FATTR_MTIME = (1 << 5) + FATTR_FH = (1 << 6) + FATTR_ATIME_NOW = (1 << 7) + FATTR_MTIME_NOW = (1 << 8) + FATTR_LOCKOWNER = (1 << 9) + FATTR_CTIME = (1 << 10) +) + +// FUSESetAttrIn is the request sent by the kernel to the daemon, +// to set the attribute(s) of a file. +// +// +marshal +type FUSESetAttrIn struct { + // Valid indicates which attributes are modified by this request. + Valid uint32 + + _ uint32 + + // Fh is used to identify the file if FATTR_FH is set in Valid. + Fh uint64 + + // Size is the size that the request wants to change to. + Size uint64 + + // LockOwner is the owner of the lock that the request wants to change to. + LockOwner uint64 + + // Atime is the access time that the request wants to change to. + Atime uint64 + + // Mtime is the modification time that the request wants to change to. + Mtime uint64 + + // Ctime is the status change time that the request wants to change to. + Ctime uint64 + + // AtimeNsec is the nano second part of Atime. + AtimeNsec uint32 + + // MtimeNsec is the nano second part of Mtime. + MtimeNsec uint32 + + // CtimeNsec is the nano second part of Ctime. + CtimeNsec uint32 + + // Mode is the file mode that the request wants to change to. + Mode uint32 + + _ uint32 + + // UID is the user ID of the owner that the request wants to change to. + UID uint32 + + // GID is the group ID of the owner that the request wants to change to. + GID uint32 + + _ uint32 +} + +// FUSEUnlinkIn is the request sent by the kernel to the daemon +// when trying to unlink a node. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEUnlinkIn struct { + marshal.StubMarshallable + + // Name of the node to unlink. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer, which should +// have size len(r.Name) + 1 and last byte set to 0. +func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEUnlinkIn. +// 1 extra byte for null-terminated Name string. +func (r *FUSEUnlinkIn) SizeBytes() int { + return len(r.Name) + 1 +} diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index d6dbedc3e..7df02dd6d 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -113,6 +113,35 @@ const ( _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS ) +// Constants from uapi/linux/fs.h. +const ( + FS_IOC_GETFLAGS = 2148034049 + FS_VERITY_FL = 1048576 +) + +// Constants from uapi/linux/fsverity.h. +const ( + FS_IOC_ENABLE_VERITY = 1082156677 + FS_IOC_MEASURE_VERITY = 3221513862 +) + +// DigestMetadata is a helper struct for VerityDigest. +// +// +marshal +type DigestMetadata struct { + DigestAlgorithm uint16 + DigestSize uint16 +} + +// SizeOfDigestMetadata is the size of struct DigestMetadata. +const SizeOfDigestMetadata = 4 + +// VerityDigest is struct from uapi/linux/fsverity.h. +type VerityDigest struct { + Metadata DigestMetadata + Digest []byte +} + // IOC outputs the result of _IOC macro in asm-generic/ioctl.h. func IOC(dir, typ, nr, size uint32) uint32 { return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go index 22acd2d43..c6e65df62 100644 --- a/pkg/abi/linux/ipc.go +++ b/pkg/abi/linux/ipc.go @@ -37,6 +37,8 @@ const IPC_PRIVATE = 0 // features like 32-bit UIDs. // IPCPerm is equivalent to struct ipc64_perm. +// +// +marshal type IPCPerm struct { Key uint32 UID uint32 diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go index 281acdbde..3b4abece1 100644 --- a/pkg/abi/linux/linux.go +++ b/pkg/abi/linux/linux.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux contains the constants and types needed to interface with a Linux kernel. +// Package linux contains the constants and types needed to interface with a +// Linux kernel. package linux // NumSoftIRQ is the number of software IRQs, exposed via /proc/stat. @@ -21,6 +22,8 @@ package linux const NumSoftIRQ = 10 // Sysinfo is the structure provided by sysinfo on linux versions > 2.3.48. +// +// +marshal type Sysinfo struct { Uptime int64 Loads [3]uint64 @@ -34,6 +37,6 @@ type Sysinfo struct { _ [6]byte // Pad Procs to 64bits. TotalHigh uint64 FreeHigh uint64 - Unit uint32 - /* The _f field in the glibc version of Sysinfo has size 0 on AMD64 */ + Unit uint32 `marshal:"unaligned"` // Struct ends mid-64-bit-word. + // The _f field in the glibc version of Sysinfo has size 0 on AMD64. } diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 91e35366f..b521144d9 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -17,9 +17,9 @@ package linux import ( "io" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // This file contains structures required to support netfilter, specifically @@ -265,6 +265,18 @@ type KernelXTEntryMatch struct { Data []byte } +// XTGetRevision corresponds to xt_get_revision in +// include/uapi/linux/netfilter/x_tables.h +// +// +marshal +type XTGetRevision struct { + Name ExtensionName + Revision uint8 +} + +// SizeOfXTGetRevision is the size of an XTGetRevision. +const SizeOfXTGetRevision = 30 + // XTEntryTarget holds a target for a rule. For example, it can specify that // packets matching the rule should DROP, ACCEPT, or use an extension target. // iptables-extension(8) has a list of possible targets. @@ -285,6 +297,13 @@ type XTEntryTarget struct { // SizeOfXTEntryTarget is the size of an XTEntryTarget. const SizeOfXTEntryTarget = 32 +// KernelXTEntryTarget is identical to XTEntryTarget, but contains a +// variable-length Data field. +type KernelXTEntryTarget struct { + XTEntryTarget + Data []byte +} + // XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, // RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. @@ -450,9 +469,9 @@ func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { } // CopyIn implements marshal.Marshallable.CopyIn. -func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. - length, err := task.CopyInBytes(addr, buf) // escapes: okay. +func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. // Unmarshal unconditionally. If we had a short copy-in, this results in a // partially unmarshalled struct. ke.UnmarshalBytes(buf) // escapes: fallback. @@ -460,21 +479,21 @@ func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int } // CopyOut implements marshal.Marshallable.CopyOut. -func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { +func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall // back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)) } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { +func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall // back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) } -func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { - buf := task.CopyScratchBuffer(ke.SizeBytes()) +func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) ke.MarshalBytes(buf) return buf } @@ -510,6 +529,8 @@ type IPTReplace struct { const SizeOfIPTReplace = 96 // ExtensionName holds the name of a netfilter extension. +// +// +marshal type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index f6117024c..6d31eb5e3 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -17,9 +17,9 @@ package linux import ( "io" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // This file contains structures required to support IPv6 netfilter and @@ -128,9 +128,9 @@ func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) { } // CopyIn implements marshal.Marshallable.CopyIn. -func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. - length, err := task.CopyInBytes(addr, buf) // escapes: okay. +func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. // Unmarshal unconditionally. If we had a short copy-in, this results // in a partially unmarshalled struct. ke.UnmarshalBytes(buf) // escapes: fallback. @@ -138,21 +138,21 @@ func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (in } // CopyOut implements marshal.Marshallable.CopyOut. -func (ke *KernelIP6TGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { +func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { // Type KernelIP6TGetEntries doesn't have a packed layout in memory, // fall back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)) } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (ke *KernelIP6TGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { +func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall // back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) } -func (ke *KernelIP6TGetEntries) marshalAll(task marshal.Task) []byte { - buf := task.CopyScratchBuffer(ke.SizeBytes()) +func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) ke.MarshalBytes(buf) return buf } @@ -321,3 +321,16 @@ const ( // Enable all flags. IP6T_INV_MASK = 0x7F ) + +// NFNATRange corresponds to struct nf_nat_range in +// include/uapi/linux/netfilter/nf_nat.h. +type NFNATRange struct { + Flags uint32 + MinAddr Inet6Addr + MaxAddr Inet6Addr + MinProto uint16 // Network byte order. + MaxProto uint16 // Network byte order. +} + +// SizeOfNFNATRange is the size of NFNATRange. +const SizeOfNFNATRange = 40 diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index 0ba086c76..b41f94a69 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -40,6 +40,8 @@ const ( ) // SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h. +// +// +marshal type SockAddrNetlink struct { Family uint16 _ uint16 diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go index c04d26e4c..3443a5768 100644 --- a/pkg/abi/linux/poll.go +++ b/pkg/abi/linux/poll.go @@ -15,6 +15,8 @@ package linux // PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h. +// +// +marshal slice:PollFDSlice type PollFD struct { FD int32 Events int16 diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go index d8302dc85..e29d0ac7e 100644 --- a/pkg/abi/linux/rusage.go +++ b/pkg/abi/linux/rusage.go @@ -26,6 +26,8 @@ const ( ) // Rusage represents the Linux struct rusage. +// +// +marshal type Rusage struct { UTime Timeval STime Timeval diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go index d0607e256..b07cafe12 100644 --- a/pkg/abi/linux/seccomp.go +++ b/pkg/abi/linux/seccomp.go @@ -34,11 +34,11 @@ type BPFAction uint32 const ( SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000 - SECCOMP_RET_KILL_THREAD = 0x00000000 - SECCOMP_RET_TRAP = 0x00030000 - SECCOMP_RET_ERRNO = 0x00050000 - SECCOMP_RET_TRACE = 0x7ff00000 - SECCOMP_RET_ALLOW = 0x7fff0000 + SECCOMP_RET_KILL_THREAD BPFAction = 0x00000000 + SECCOMP_RET_TRAP BPFAction = 0x00030000 + SECCOMP_RET_ERRNO BPFAction = 0x00050000 + SECCOMP_RET_TRACE BPFAction = 0x7ff00000 + SECCOMP_RET_ALLOW BPFAction = 0x7fff0000 ) func (a BPFAction) String() string { @@ -64,6 +64,19 @@ func (a BPFAction) Data() uint16 { return uint16(a & SECCOMP_RET_DATA) } +// WithReturnCode sets the lower 16 bits of the SECCOMP_RET_ERRNO or +// SECCOMP_RET_TRACE actions to the provided return code, overwriting the previous +// action, and returns a new BPFAction. If not SECCOMP_RET_ERRNO or +// SECCOMP_RET_TRACE then this panics. +func (a BPFAction) WithReturnCode(code uint16) BPFAction { + // mask out the previous return value + baseAction := a & SECCOMP_RET_ACTION_FULL + if baseAction == SECCOMP_RET_ERRNO || baseAction == SECCOMP_RET_TRACE { + return BPFAction(uint32(baseAction) | uint32(code)) + } + panic("WithReturnCode only valid for SECCOMP_RET_ERRNO and SECCOMP_RET_TRACE") +} + // SockFprog is sock_fprog taken from <linux/filter.h>. type SockFprog struct { Len uint16 diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index de422c519..487a626cc 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -35,6 +35,8 @@ const ( const SEM_UNDO = 0x1000 // SemidDS is equivalent to struct semid64_ds. +// +// +marshal type SemidDS struct { SemPerm IPCPerm SemOTime TimeT @@ -45,6 +47,8 @@ type SemidDS struct { } // Sembuf is equivalent to struct sembuf. +// +// +marshal slice:SembufSlice type Sembuf struct { SemNum uint16 SemOp int16 diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go index e45aadb10..274b1e847 100644 --- a/pkg/abi/linux/shm.go +++ b/pkg/abi/linux/shm.go @@ -51,6 +51,8 @@ const ( // ShmidDS is equivalent to struct shmid64_ds. Source: // include/uapi/asm-generic/shmbuf.h +// +// +marshal type ShmidDS struct { ShmPerm IPCPerm ShmSegsz uint64 @@ -66,6 +68,8 @@ type ShmidDS struct { } // ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h +// +// +marshal type ShmParams struct { ShmMax uint64 ShmMin uint64 @@ -75,6 +79,8 @@ type ShmParams struct { } // ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h +// +// +marshal type ShmInfo struct { UsedIDs int32 // Number of currently existing segments. _ [4]byte diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index 1c330e763..6ca57ffbb 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -214,6 +214,8 @@ const ( ) // Sigevent represents struct sigevent. +// +// +marshal type Sigevent struct { Value uint64 // union sigval {int, void*} Signo int32 diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index e37c8727d..d156d41e4 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -14,7 +14,10 @@ package linux -import "gvisor.dev/gvisor/pkg/binary" +import ( + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" +) // Address families, from linux/socket.h. const ( @@ -265,6 +268,8 @@ type InetMulticastRequestWithNIC struct { type Inet6Addr [16]byte // SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h. +// +// +marshal type SockAddrInet6 struct { Family uint16 Port uint16 @@ -274,6 +279,8 @@ type SockAddrInet6 struct { } // SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +// +// +marshal type SockAddrLink struct { Family uint16 Protocol uint16 @@ -290,6 +297,8 @@ type SockAddrLink struct { const UnixPathMax = 108 // SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h. +// +// +marshal type SockAddrUnix struct { Family uint16 Path [UnixPathMax]int8 @@ -299,6 +308,8 @@ type SockAddrUnix struct { // equivalent to struct sockaddr. SockAddr ensures that a well-defined set of // types can be used as socket addresses. type SockAddr interface { + marshal.Marshallable + // implementsSockAddr exists purely to allow a type to indicate that they // implement this interface. This method is a no-op and shouldn't be called. implementsSockAddr() diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go index e6860ed49..206f5af7e 100644 --- a/pkg/abi/linux/time.go +++ b/pkg/abi/linux/time.go @@ -93,6 +93,8 @@ const ( const maxSecInDuration = math.MaxInt64 / int64(time.Second) // TimeT represents time_t in <time.h>. It represents time in seconds. +// +// +marshal type TimeT int64 // NsecToTimeT translates nanoseconds to TimeT (seconds). @@ -102,7 +104,7 @@ func NsecToTimeT(nsec int64) TimeT { // Timespec represents struct timespec in <time.h>. // -// +marshal +// +marshal slice:TimespecSlice type Timespec struct { Sec int64 Nsec int64 @@ -158,7 +160,7 @@ const SizeOfTimeval = 16 // Timeval represents struct timeval in <time.h>. // -// +marshal +// +marshal slice:TimevalSlice type Timeval struct { Sec int64 Usec int64 @@ -196,6 +198,8 @@ func DurationToTimeval(dur time.Duration) Timeval { } // Itimerspec represents struct itimerspec in <time.h>. +// +// +marshal type Itimerspec struct { Interval Timespec Value Timespec @@ -206,12 +210,16 @@ type Itimerspec struct { // struct timeval it_interval; /* next value */ // struct timeval it_value; /* current value */ // }; +// +// +marshal type ItimerVal struct { Interval Timeval Value Timeval } // ClockT represents type clock_t. +// +// +marshal type ClockT int64 // ClockTFromDuration converts time.Duration to clock_t. @@ -220,6 +228,8 @@ func ClockTFromDuration(d time.Duration) ClockT { } // Tms represents struct tms, used by times(2). +// +// +marshal type Tms struct { UTime ClockT STime ClockT @@ -229,6 +239,8 @@ type Tms struct { // TimerID represents type timer_t, which identifies a POSIX per-process // interval timer. +// +// +marshal type TimerID int32 // StatxTimestamp represents struct statx_timestamp. diff --git a/pkg/abi/linux/tty.go b/pkg/abi/linux/tty.go index e640969a6..47e65d9fb 100644 --- a/pkg/abi/linux/tty.go +++ b/pkg/abi/linux/tty.go @@ -325,9 +325,9 @@ var MasterTermios = KernelTermios{ OutputSpeed: 38400, } -// DefaultSlaveTermios is the default terminal configuration of the slave end -// of a Unix98 pseudoterminal. -var DefaultSlaveTermios = KernelTermios{ +// DefaultReplicaTermios is the default terminal configuration of the replica +// end of a Unix98 pseudoterminal. +var DefaultReplicaTermios = KernelTermios{ InputFlags: ICRNL | IXON, OutputFlags: OPOST | ONLCR, ControlFlags: B38400 | CS8 | CREAD, @@ -341,6 +341,7 @@ var DefaultSlaveTermios = KernelTermios{ // include/uapi/asm-generic/termios.h. // // +stateify savable +// +marshal type WindowSize struct { Rows uint16 Cols uint16 diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go index 60f220a67..cb7c95437 100644 --- a/pkg/abi/linux/utsname.go +++ b/pkg/abi/linux/utsname.go @@ -26,6 +26,8 @@ const ( ) // UtsName represents struct utsname, the struct returned by uname(2). +// +// +marshal type UtsName struct { Sysname [UTSLen + 1]byte Nodename [UTSLen + 1]byte diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index ffc918846..bd3a5cce9 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -6,7 +6,10 @@ go_library( name = "amutex", srcs = ["amutex.go"], visibility = ["//:sandbox"], - deps = ["//pkg/syserror"], + deps = [ + "//pkg/context", + "//pkg/syserror", + ], ) go_test( diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go index a078a31db..d7acc1d9f 100644 --- a/pkg/amutex/amutex.go +++ b/pkg/amutex/amutex.go @@ -19,41 +19,17 @@ package amutex import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) // Sleeper must be implemented by users of the abortable mutex to allow for // cancellation of waits. -type Sleeper interface { - // SleepStart is called by the AbortableMutex.Lock() function when the - // mutex is contended and the goroutine is about to sleep. - // - // A channel can be returned that causes the sleep to be canceled if - // it's readable. If no cancellation is desired, nil can be returned. - SleepStart() <-chan struct{} - - // SleepFinish is called by AbortableMutex.Lock() once a contended mutex - // is acquired or the wait is aborted. - SleepFinish(success bool) - - // Interrupted returns true if the wait is aborted. - Interrupted() bool -} +type Sleeper = context.ChannelSleeper // NoopSleeper is a stateless no-op implementation of Sleeper for anonymous // embedding in other types that do not support cancelation. -type NoopSleeper struct{} - -// SleepStart implements Sleeper.SleepStart. -func (NoopSleeper) SleepStart() <-chan struct{} { - return nil -} - -// SleepFinish implements Sleeper.SleepFinish. -func (NoopSleeper) SleepFinish(success bool) {} - -// Interrupted implements Sleeper.Interrupted. -func (NoopSleeper) Interrupted() bool { return false } +type NoopSleeper = context.Context // Block blocks until either receiving from ch succeeds (in which case it // returns nil) or sleeper is interrupted (in which case it returns diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index c8ee0c3b1..069d0395d 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -21,10 +21,15 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// DecodeProgram translates an array of BPF instructions into text format. -func DecodeProgram(program []linux.BPFInstruction) (string, error) { +// DecodeProgram translates a compiled BPF program into text format. +func DecodeProgram(p Program) (string, error) { + return DecodeInstructions(p.instructions) +} + +// DecodeInstructions translates an array of BPF instructions into text format. +func DecodeInstructions(instns []linux.BPFInstruction) (string, error) { var ret bytes.Buffer - for line, s := range program { + for line, s := range instns { ret.WriteString(fmt.Sprintf("%v: ", line)) if err := decode(s, line, &ret); err != nil { return "", err @@ -34,7 +39,7 @@ func DecodeProgram(program []linux.BPFInstruction) (string, error) { return ret.String(), nil } -// Decode translates BPF instruction into text format. +// Decode translates a single BPF instruction into text format. func Decode(inst linux.BPFInstruction) (string, error) { var ret bytes.Buffer err := decode(inst, -1, &ret) diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go index 6a023f0c0..bb971ce21 100644 --- a/pkg/bpf/decoder_test.go +++ b/pkg/bpf/decoder_test.go @@ -93,7 +93,7 @@ func TestDecode(t *testing.T) { } } -func TestDecodeProgram(t *testing.T) { +func TestDecodeInstructions(t *testing.T) { for _, test := range []struct { name string program []linux.BPFInstruction @@ -126,7 +126,7 @@ func TestDecodeProgram(t *testing.T) { program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)}, fail: true}, } { - got, err := DecodeProgram(test.program) + got, err := DecodeInstructions(test.program) if test.fail { if err == nil { t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got) diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go index 7992044d0..caaf99c83 100644 --- a/pkg/bpf/program_builder.go +++ b/pkg/bpf/program_builder.go @@ -32,13 +32,21 @@ type ProgramBuilder struct { // Maps label names to label objects. labels map[string]*label + // unusableLabels are labels that are added before being referenced in a + // jump. Any labels added this way cannot be referenced later in order to + // avoid backwards references. + unusableLabels map[string]bool + // Array of BPF instructions that makes up the program. instructions []linux.BPFInstruction } // NewProgramBuilder creates a new ProgramBuilder instance. func NewProgramBuilder() *ProgramBuilder { - return &ProgramBuilder{labels: map[string]*label{}} + return &ProgramBuilder{ + labels: map[string]*label{}, + unusableLabels: map[string]bool{}, + } } // label contains information to resolve a label to an offset. @@ -108,9 +116,12 @@ func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel s func (b *ProgramBuilder) AddLabel(name string) error { l, ok := b.labels[name] if !ok { - // This is done to catch jump backwards cases, but it's not strictly wrong - // to have unused labels. - return fmt.Errorf("Adding a label that hasn't been used is not allowed: %v", name) + if _, ok = b.unusableLabels[name]; ok { + return fmt.Errorf("label %q already set", name) + } + // Mark the label as unusable. This is done to catch backwards jumps. + b.unusableLabels[name] = true + return nil } if l.target != -1 { return fmt.Errorf("label %q target already set: %v", name, l.target) @@ -141,6 +152,10 @@ func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) { func (b *ProgramBuilder) resolveLabels() error { for key, v := range b.labels { + if _, ok := b.unusableLabels[key]; ok { + return fmt.Errorf("backwards reference detected for label: %q", key) + } + if v.target == -1 { return fmt.Errorf("label target not set: %v", key) } diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go index 92ca5f4c3..37f684f25 100644 --- a/pkg/bpf/program_builder_test.go +++ b/pkg/bpf/program_builder_test.go @@ -26,16 +26,16 @@ func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error { if err != nil { return fmt.Errorf("Instructions() failed: %v", err) } - got, err := DecodeProgram(instructions) + got, err := DecodeInstructions(instructions) if err != nil { - return fmt.Errorf("DecodeProgram('instructions') failed: %v", err) + return fmt.Errorf("DecodeInstructions('instructions') failed: %v", err) } - expectedDecoded, err := DecodeProgram(expected) + expectedDecoded, err := DecodeInstructions(expected) if err != nil { - return fmt.Errorf("DecodeProgram('expected') failed: %v", err) + return fmt.Errorf("DecodeInstructions('expected') failed: %v", err) } if got != expectedDecoded { - return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got) + return fmt.Errorf("DecodeInstructions() failed, expected: %q, got: %q", expectedDecoded, got) } return nil } @@ -124,10 +124,38 @@ func TestProgramBuilderLabelWithNoInstruction(t *testing.T) { } } +// TestProgramBuilderUnusedLabel tests that adding an unused label doesn't +// cause program generation to fail. func TestProgramBuilderUnusedLabel(t *testing.T) { p := NewProgramBuilder() - if err := p.AddLabel("unused"); err == nil { - t.Errorf("AddLabel(unused) should have failed") + p.AddStmt(Ld+Abs+W, 10) + p.AddJump(Jmp+Ja, 10, 0, 0) + + expected := []linux.BPFInstruction{ + Stmt(Ld+Abs+W, 10), + Jump(Jmp+Ja, 10, 0, 0), + } + + if err := p.AddLabel("unused"); err != nil { + t.Errorf("AddLabel(unused) should have succeeded") + } + + if err := validate(p, expected); err != nil { + t.Errorf("Validate() failed: %v", err) + } +} + +// TestProgramBuilderBackwardsReference tests that including a backwards +// reference to a label in a program causes a failure. +func TestProgramBuilderBackwardsReference(t *testing.T) { + p := NewProgramBuilder() + if err := p.AddLabel("bw_label"); err != nil { + t.Errorf("failed to add label") + } + p.AddStmt(Ld+Abs+W, 10) + p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "bw_label", 0) + if _, err := p.Instructions(); err == nil { + t.Errorf("Instructions() should have failed") } } diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index b03d46d18..1186f788e 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -20,6 +20,7 @@ go_library( srcs = [ "buffer.go", "buffer_list.go", + "pool.go", "safemem.go", "view.go", "view_unsafe.go", @@ -37,9 +38,13 @@ go_test( name = "buffer_test", size = "small", srcs = [ + "pool_test.go", "safemem_test.go", "view_test.go", ], library = ":buffer", - deps = ["//pkg/safemem"], + deps = [ + "//pkg/safemem", + "//pkg/state", + ], ) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index c6d089fd9..311808ae9 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -14,36 +14,26 @@ // Package buffer provides the implementation of a buffer view. // -// A view is an flexible buffer, backed by a pool, supporting the safecopy -// operations natively as well as the ability to grow via either prepend or -// append, as well as shrink. +// A view is an flexible buffer, supporting the safecopy operations natively as +// well as the ability to grow via either prepend or append, as well as shrink. package buffer -import ( - "sync" -) - -const bufferSize = 8144 // See below. - // buffer encapsulates a queueable byte buffer. // -// Note that the total size is slightly less than two pages. This is done -// intentionally to ensure that the buffer object aligns with runtime -// internals. We have no hard size or alignment requirements. This two page -// size will effectively minimize internal fragmentation, but still have a -// large enough chunk to limit excessive segmentation. -// // +stateify savable type buffer struct { - data [bufferSize]byte + data []byte read int write int bufferEntry } -// reset resets internal data. -// -// This must be called before returning the buffer to the pool. +// init performs in-place initialization for zero value. +func (b *buffer) init(size int) { + b.data = make([]byte, size) +} + +// Reset resets read and write locations, effectively emptying the buffer. func (b *buffer) Reset() { b.read = 0 b.write = 0 @@ -85,10 +75,3 @@ func (b *buffer) WriteMove(n int) { func (b *buffer) WriteSlice() []byte { return b.data[b.write:] } - -// bufferPool is a pool for buffers. -var bufferPool = sync.Pool{ - New: func() interface{} { - return new(buffer) - }, -} diff --git a/pkg/buffer/pool.go b/pkg/buffer/pool.go new file mode 100644 index 000000000..7ad6132ab --- /dev/null +++ b/pkg/buffer/pool.go @@ -0,0 +1,83 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +const ( + // embeddedCount is the number of buffer structures embedded in the pool. It + // is also the number for overflow allocations. + embeddedCount = 8 + + // defaultBufferSize is the default size for each underlying storage buffer. + // + // It is slightly less than two pages. This is done intentionally to ensure + // that the buffer object aligns with runtime internals. This two page size + // will effectively minimize internal fragmentation, but still have a large + // enough chunk to limit excessive segmentation. + defaultBufferSize = 8144 +) + +// pool allocates buffer. +// +// It contains an embedded buffer storage for fast path when the number of +// buffers needed is small. +// +// +stateify savable +type pool struct { + bufferSize int + avail []buffer `state:"nosave"` + embeddedStorage [embeddedCount]buffer `state:"wait"` +} + +// get gets a new buffer from p. +func (p *pool) get() *buffer { + if p.avail == nil { + p.avail = p.embeddedStorage[:] + } + if len(p.avail) == 0 { + p.avail = make([]buffer, embeddedCount) + } + if p.bufferSize <= 0 { + p.bufferSize = defaultBufferSize + } + buf := &p.avail[0] + buf.init(p.bufferSize) + p.avail = p.avail[1:] + return buf +} + +// put releases buf. +func (p *pool) put(buf *buffer) { + // Remove reference to the underlying storage, allowing it to be garbage + // collected. + buf.data = nil +} + +// setBufferSize sets the size of underlying storage buffer for future +// allocations. It can be called at any time. +func (p *pool) setBufferSize(size int) { + p.bufferSize = size +} + +// afterLoad is invoked by stateify. +func (p *pool) afterLoad() { + // S/R does not save subslice into embeddedStorage correctly. Restore + // available portion of embeddedStorage manually. Restore as nil if none used. + for i := len(p.embeddedStorage); i > 0; i-- { + if p.embeddedStorage[i-1].data != nil { + p.avail = p.embeddedStorage[i:] + break + } + } +} diff --git a/pkg/buffer/pool_test.go b/pkg/buffer/pool_test.go new file mode 100644 index 000000000..8584bac89 --- /dev/null +++ b/pkg/buffer/pool_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "testing" +) + +func TestGetDefaultBufferSize(t *testing.T) { + var p pool + for i := 0; i < embeddedCount*2; i++ { + buf := p.get() + if got, want := len(buf.data), defaultBufferSize; got != want { + t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want) + } + } +} + +func TestGetCustomBufferSize(t *testing.T) { + const size = 100 + + var p pool + p.setBufferSize(size) + for i := 0; i < embeddedCount*2; i++ { + buf := p.get() + if got, want := len(buf.data), size; got != want { + t.Errorf("#%d len(buf.data) = %d, want %d", i, got, want) + } + } +} + +func TestPut(t *testing.T) { + var p pool + buf := p.get() + p.put(buf) + if buf.data != nil { + t.Errorf("buf.data = %x, want nil", buf.data) + } +} diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go index b789e56e9..8b42575b4 100644 --- a/pkg/buffer/safemem.go +++ b/pkg/buffer/safemem.go @@ -44,7 +44,7 @@ func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, e // Need at least one buffer. firstBuf := v.data.Back() if firstBuf == nil { - firstBuf = bufferPool.Get().(*buffer) + firstBuf = v.pool.get() v.data.PushBack(firstBuf) } @@ -56,7 +56,7 @@ func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, e count -= l blocks = append(blocks, firstBuf.WriteBlock()) for count > 0 { - emptyBuf := bufferPool.Get().(*buffer) + emptyBuf := v.pool.get() v.data.PushBack(emptyBuf) block := emptyBuf.WriteBlock().TakeFirst64(count) count -= uint64(block.Len()) diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go index 47f357e0c..721cc5934 100644 --- a/pkg/buffer/safemem_test.go +++ b/pkg/buffer/safemem_test.go @@ -23,6 +23,8 @@ import ( ) func TestSafemem(t *testing.T) { + const bufferSize = defaultBufferSize + testCases := []struct { name string input string diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index e6901eadb..00652d675 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -27,6 +27,7 @@ import ( type View struct { data bufferList size int64 + pool pool } // TrimFront removes the first count bytes from the buffer. @@ -81,7 +82,7 @@ func (v *View) advanceRead(count int64) { buf = buf.Next() // Iterate. v.data.Remove(oldBuf) oldBuf.Reset() - bufferPool.Put(oldBuf) + v.pool.put(oldBuf) // Update counts. count -= sz @@ -118,7 +119,7 @@ func (v *View) Truncate(length int64) { // Drop the buffer completely; see above. v.data.Remove(buf) buf.Reset() - bufferPool.Put(buf) + v.pool.put(buf) v.size -= sz } } @@ -137,7 +138,7 @@ func (v *View) Grow(length int64, zero bool) { // Is there some space in the last buffer? if buf == nil || buf.Full() { - buf = bufferPool.Get().(*buffer) + buf = v.pool.get() v.data.PushBack(buf) } @@ -181,7 +182,7 @@ func (v *View) Prepend(data []byte) { for len(data) > 0 { // Do we need an empty buffer? - buf := bufferPool.Get().(*buffer) + buf := v.pool.get() v.data.PushFront(buf) // The buffer is empty; copy last chunk. @@ -211,7 +212,7 @@ func (v *View) Append(data []byte) { // Ensure there's a buffer with space. if buf == nil || buf.Full() { - buf = bufferPool.Get().(*buffer) + buf = v.pool.get() v.data.PushBack(buf) } @@ -297,7 +298,7 @@ func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { // Ensure we have an empty buffer. if buf == nil || buf.Full() { - buf = bufferPool.Get().(*buffer) + buf = v.pool.get() v.data.PushBack(buf) } diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 3db1bc6ee..839af0223 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -16,11 +16,16 @@ package buffer import ( "bytes" + "context" "io" "strings" "testing" + + "gvisor.dev/gvisor/pkg/state" ) +const bufferSize = defaultBufferSize + func fillAppend(v *View, data []byte) { v.Append(data) } @@ -50,6 +55,30 @@ var fillFuncs = map[string]func(*View, []byte){ "writeFromReaderEnd": fillWriteFromReaderEnd, } +func BenchmarkReadAt(b *testing.B) { + b.ReportAllocs() + var v View + v.Append(make([]byte, 100)) + + buf := make([]byte, 10) + for i := 0; i < b.N; i++ { + v.ReadAt(buf, 0) + } +} + +func BenchmarkWriteRead(b *testing.B) { + b.ReportAllocs() + var v View + sz := 1000 + wbuf := make([]byte, sz) + rbuf := bytes.NewBuffer(make([]byte, sz)) + for i := 0; i < b.N; i++ { + v.Append(wbuf) + rbuf.Reset() + v.ReadToWriter(rbuf, int64(sz)) + } +} + func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) { t.Helper() d := make([]byte, n) @@ -465,3 +494,51 @@ func TestView(t *testing.T) { } } } + +func doSaveAndLoad(t *testing.T, toSave, toLoad *View) { + t.Helper() + var buf bytes.Buffer + ctx := context.Background() + if _, err := state.Save(ctx, &buf, toSave); err != nil { + t.Fatal("state.Save:", err) + } + if _, err := state.Load(ctx, bytes.NewReader(buf.Bytes()), toLoad); err != nil { + t.Fatal("state.Load:", err) + } +} + +func TestSaveRestoreViewEmpty(t *testing.T) { + var toSave View + var v View + doSaveAndLoad(t, &toSave, &v) + + if got := v.pool.avail; got != nil { + t.Errorf("pool is not in zero state: v.pool.avail = %v, want nil", got) + } + if got := v.Flatten(); len(got) != 0 { + t.Errorf("v.Flatten() = %x, want []", got) + } +} + +func TestSaveRestoreView(t *testing.T) { + // Create data that fits 2.5 slots. + data := bytes.Join([][]byte{ + bytes.Repeat([]byte{1, 2}, defaultBufferSize), + bytes.Repeat([]byte{3}, defaultBufferSize/2), + }, nil) + + var toSave View + toSave.Append(data) + + var v View + doSaveAndLoad(t, &toSave, &v) + + // Next available slot at index 3; 0-2 slot are used. + i := 3 + if got, want := &v.pool.avail[0], &v.pool.embeddedStorage[i]; got != want { + t.Errorf("next available buffer points to %p, want %p (&v.pool.embeddedStorage[%d])", got, want, i) + } + if got := v.Flatten(); !bytes.Equal(got, data) { + t.Errorf("v.Flatten() = %x, want %x", got, data) + } +} diff --git a/pkg/context/BUILD b/pkg/context/BUILD index 239f31149..f33e23bf7 100644 --- a/pkg/context/BUILD +++ b/pkg/context/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["context.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/amutex", "//pkg/log", ], ) diff --git a/pkg/context/context.go b/pkg/context/context.go index 5319b6d8d..2613bc752 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -26,7 +26,6 @@ import ( "context" "time" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -68,9 +67,10 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { // In both cases, values extracted from the Context should be used instead. type Context interface { log.Logger - amutex.Sleeper context.Context + ChannelSleeper + // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate // is true and the Context represents a Task, the Task's AddressSpace is @@ -85,29 +85,60 @@ type Context interface { UninterruptibleSleepFinish(activate bool) } -// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep -// methods for anonymous embedding in other types that do not implement sleeps. -type NoopSleeper struct { - amutex.NoopSleeper +// A ChannelSleeper represents a goroutine that may sleep interruptibly, where +// interruption is indicated by a channel becoming readable. +type ChannelSleeper interface { + // SleepStart is called before going to sleep interruptibly. If SleepStart + // returns a non-nil channel and that channel becomes ready for receiving + // while the goroutine is sleeping, the goroutine should be woken, and + // SleepFinish(false) should be called. Otherwise, SleepFinish(true) should + // be called after the goroutine stops sleeping. + SleepStart() <-chan struct{} + + // SleepFinish is called after an interruptibly-sleeping goroutine stops + // sleeping, as documented by SleepStart. + SleepFinish(success bool) + + // Interrupted returns true if the channel returned by SleepStart is + // ready for receiving. + Interrupted() bool +} + +// NoopSleeper is a noop implementation of ChannelSleeper and +// Context.UninterruptibleSleep* methods for anonymous embedding in other types +// that do not implement special behavior around sleeps. +type NoopSleeper struct{} + +// SleepStart implements ChannelSleeper.SleepStart. +func (NoopSleeper) SleepStart() <-chan struct{} { + return nil +} + +// SleepFinish implements ChannelSleeper.SleepFinish. +func (NoopSleeper) SleepFinish(success bool) {} + +// Interrupted implements ChannelSleeper.Interrupted. +func (NoopSleeper) Interrupted() bool { + return false } -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} +// UninterruptibleSleepStart implements Context.UninterruptibleSleepStart. +func (NoopSleeper) UninterruptibleSleepStart(deactivate bool) {} -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} +// UninterruptibleSleepFinish implements Context.UninterruptibleSleepFinish. +func (NoopSleeper) UninterruptibleSleepFinish(activate bool) {} -// Deadline returns zero values, meaning no deadline. +// Deadline implements context.Context.Deadline. func (NoopSleeper) Deadline() (time.Time, bool) { return time.Time{}, false } -// Done returns nil. +// Done implements context.Context.Done. func (NoopSleeper) Done() <-chan struct{} { return nil } -// Err returns nil. +// Err returns context.Context.Err. func (NoopSleeper) Err() error { return nil } diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index 6831adcce..a4f4b2c5e 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -100,12 +100,9 @@ var coveragePool = sync.Pool{ // instrumentation_filter. // // Note that we "consume", i.e. clear, coverdata when this function is run, to -// ensure that each event is only reported once. -// -// TODO(b/160639712): evaluate whether it is ok to reset the global coverage -// data every time this function is run. We could technically have each thread -// store a local snapshot against which we compare the most recent coverdata so -// that separate threads do not affect each other's view of the data. +// ensure that each event is only reported once. Due to the limitations of Go +// coverage tools, we reset the global coverage data every time this function is +// run. func ConsumeCoverageData(w io.Writer) int { once.Do(initCoverageData) @@ -117,23 +114,23 @@ func ConsumeCoverageData(w io.Writer) int { for fileIndex, file := range globalData.files { counters := coverdata.Cover.Counters[file] for index := 0; index < len(counters); index++ { - val := atomic.SwapUint32(&counters[index], 0) - if val != 0 { - // Calculate the synthetic PC. - pc := globalData.syntheticPCs[fileIndex][index] - - usermem.ByteOrder.PutUint64(pcBuffer[:], pc) - n, err := w.Write(pcBuffer[:]) - if err != nil { - if err == io.EOF { - // Simply stop writing if we encounter EOF; it's ok if we attempted to - // write more than we can hold. - return total + n - } - panic(fmt.Sprintf("Internal error writing PCs to kcov area: %v", err)) + if atomic.LoadUint32(&counters[index]) == 0 { + continue + } + // Non-zero coverage data found; consume it and report as a PC. + atomic.StoreUint32(&counters[index], 0) + pc := globalData.syntheticPCs[fileIndex][index] + usermem.ByteOrder.PutUint64(pcBuffer[:], pc) + n, err := w.Write(pcBuffer[:]) + if err != nil { + if err == io.EOF { + // Simply stop writing if we encounter EOF; it's ok if we attempted to + // write more than we can hold. + return total + n } - total += n + panic(fmt.Sprintf("Internal error writing PCs to kcov area: %v", err)) } + total += n } } diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 83bcfe220..cc6b0cdf1 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -49,7 +49,7 @@ func fixCount(n int, err error) (int, error) { // Read implements io.Reader. func (r *ReadWriter) Read(b []byte) (int, error) { - c, err := fixCount(syscall.Read(int(atomic.LoadInt64(&r.fd)), b)) + c, err := fixCount(syscall.Read(r.FD(), b)) if c == 0 && len(b) > 0 && err == nil { return 0, io.EOF } @@ -62,7 +62,7 @@ func (r *ReadWriter) Read(b []byte) (int, error) { func (r *ReadWriter) ReadAt(b []byte, off int64) (c int, err error) { for len(b) > 0 { var m int - m, err = fixCount(syscall.Pread(int(atomic.LoadInt64(&r.fd)), b, off)) + m, err = fixCount(syscall.Pread(r.FD(), b, off)) if m == 0 && err == nil { return c, io.EOF } @@ -82,7 +82,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) { var n, remaining int for remaining = len(b); remaining > 0; { woff := len(b) - remaining - n, err = syscall.Write(int(atomic.LoadInt64(&r.fd)), b[woff:]) + n, err = syscall.Write(r.FD(), b[woff:]) if n > 0 { // syscall.Write wrote some bytes. This is the common case. @@ -110,7 +110,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) { func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) { for len(b) > 0 { var m int - m, err = fixCount(syscall.Pwrite(int(atomic.LoadInt64(&r.fd)), b, off)) + m, err = fixCount(syscall.Pwrite(r.FD(), b, off)) if err != nil { break } @@ -121,6 +121,16 @@ func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) { return } +// FD returns the owned file descriptor. Ownership remains unchanged. +func (r *ReadWriter) FD() int { + return int(atomic.LoadInt64(&r.fd)) +} + +// String implements Stringer.String(). +func (r *ReadWriter) String() string { + return fmt.Sprintf("FD: %d", r.FD()) +} + // FD owns a host file descriptor. // // It is similar to os.File, with a few important distinctions: @@ -167,6 +177,23 @@ func NewFromFile(file *os.File) (*FD, error) { return New(fd), nil } +// NewFromFiles creates new FDs for each file in the slice. +func NewFromFiles(files []*os.File) ([]*FD, error) { + rv := make([]*FD, 0, len(files)) + for _, f := range files { + new, err := NewFromFile(f) + if err != nil { + // Cleanup on error. + for _, fd := range rv { + fd.Close() + } + return nil, err + } + rv = append(rv, new) + } + return rv, nil +} + // Open is equivalent to open(2). func Open(path string, openmode int, perm uint32) (*FD, error) { f, err := syscall.Open(path, openmode|syscall.O_LARGEFILE, perm) @@ -204,11 +231,6 @@ func (f *FD) Release() int { return int(atomic.SwapInt64(&f.fd, -1)) } -// FD returns the file descriptor owned by FD. FD retains ownership. -func (f *FD) FD() int { - return int(atomic.LoadInt64(&f.fd)) -} - // File converts the FD to an os.File. // // FD does not transfer ownership of the file descriptor (it will be @@ -219,7 +241,7 @@ func (f *FD) FD() int { // This operation is somewhat expensive, so care should be taken to minimize // its use. func (f *FD) File() (*os.File, error) { - fd, err := syscall.Dup(int(atomic.LoadInt64(&f.fd))) + fd, err := syscall.Dup(f.FD()) if err != nil { return nil, err } diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md new file mode 100644 index 000000000..51d0d40e5 --- /dev/null +++ b/pkg/lisafs/README.md @@ -0,0 +1,363 @@ +# Replacing 9P + +## Background + +The Linux filesystem model consists of the following key aspects (modulo mounts, +which are outside the scope of this discussion): + +- A `struct inode` represents a "filesystem object", such as a directory or a + regular file. "Filesystem object" is most precisely defined by the practical + properties of an inode, such as an immutable type (regular file, directory, + symbolic link, etc.) and its independence from the path originally used to + obtain it. + +- A `struct dentry` represents a node in a filesystem tree. Semantically, each + dentry is immutably associated with an inode representing the filesystem + object at that position. (Linux implements optimizations involving reuse of + unreferenced dentries, which allows their associated inodes to change, but + this is outside the scope of this discussion.) + +- A `struct file` represents an open file description (hereafter FD) and is + needed to perform I/O. Each FD is immutably associated with the dentry + through which it was opened. + +The current gVisor virtual filesystem implementation (hereafter VFS1) closely +imitates the Linux design: + +- `struct inode` => `fs.Inode` + +- `struct dentry` => `fs.Dirent` + +- `struct file` => `fs.File` + +gVisor accesses most external filesystems through a variant of the 9P2000.L +protocol, including extensions for performance (`walkgetattr`) and for features +not supported by vanilla 9P2000.L (`flushf`, `lconnect`). The 9P protocol family +is inode-based; 9P fids represent a file (equivalently "file system object"), +and the protocol is structured around alternatively obtaining fids to represent +files (with `walk` and, in gVisor, `walkgetattr`) and performing operations on +those fids. + +In the sections below, a **shared** filesystem is a filesystem that is *mutably* +accessible by multiple concurrent clients, such that a **non-shared** filesystem +is a filesystem that is either read-only or accessible by only a single client. + +## Problems + +### Serialization of Path Component RPCs + +Broadly speaking, VFS1 traverses each path component in a pathname, alternating +between verifying that each traversed dentry represents an inode that represents +a searchable directory and moving to the next dentry in the path. + +In the context of a remote filesystem, the structure of this traversal means +that - modulo caching - a path involving N components requires at least N-1 +*sequential* RPCs to obtain metadata for intermediate directories, incurring +significant latency. (In vanilla 9P2000.L, 2(N-1) RPCs are required: N-1 `walk` +and N-1 `getattr`. We added the `walkgetattr` RPC to reduce this overhead.) On +non-shared filesystems, this overhead is primarily significant during +application startup; caching mitigates much of this overhead at steady state. On +shared filesystems, where correct caching requires revalidation (requiring RPCs +for each revalidated directory anyway), this overhead is consistently ruinous. + +### Inefficient RPCs + +9P is not exceptionally economical with RPCs in general. In addition to the +issue described above: + +- Opening an existing file in 9P involves at least 2 RPCs: `walk` to produce + an unopened fid representing the file, and `lopen` to open the fid. + +- Creating a file also involves at least 2 RPCs: `walk` to produce an unopened + fid representing the parent directory, and `lcreate` to create the file and + convert the fid to an open fid representing the created file. In practice, + both the Linux and gVisor 9P clients expect to have an unopened fid for the + created file (necessitating an additional `walk`), as well as attributes for + the created file (necessitating an additional `getattr`), for a total of 4 + RPCs. (In a shared filesystem, where whether a file already exists can + change between RPCs, a correct implementation of `open(O_CREAT)` would have + to alternate between these two paths (plus `clunk`ing the temporary fid + between alternations, since the nature of the `fid` differs between the two + paths). Neither Linux nor gVisor implement the required alternation, so + `open(O_CREAT)` without `O_EXCL` can spuriously fail with `EEXIST` on both.) + +- Closing (`clunk`ing) a fid requires an RPC. VFS1 issues this RPC + asynchronously in an attempt to reduce critical path latency, but scheduling + overhead makes this not clearly advantageous in practice. + +- `read` and `readdir` can return partial reads without a way to indicate EOF, + necessitating an additional final read to detect EOF. + +- Operations that affect filesystem state do not consistently return updated + filesystem state. In gVisor, the client implementation attempts to handle + this by tracking what it thinks updated state "should" be; this is complex, + and especially brittle for timestamps (which are often not arbitrarily + settable). In Linux, the client implemtation invalidates cached metadata + whenever it performs such an operation, and reloads it when a dentry + corresponding to an inode with no valid cached metadata is revalidated; this + is simple, but necessitates an additional `getattr`. + +### Dentry/Inode Ambiguity + +As noted above, 9P's documentation tends to imply that unopened fids represent +an inode. In practice, most filesystem APIs present very limited interfaces for +working with inodes at best, such that the interpretation of unopened fids +varies: + +- Linux's 9P client associates unopened fids with (dentry, uid) pairs. When + caching is enabled, it also associates each inode with the first fid opened + writably that references that inode, in order to support page cache + writeback. + +- gVisor's 9P client associates unopened fids with inodes, and also caches + opened fids in inodes in a manner similar to Linux. + +- The runsc fsgofer associates unopened fids with both "dentries" (host + filesystem paths) and "inodes" (host file descriptors); which is used + depends on the operation invoked on the fid. + +For non-shared filesystems, this confusion has resulted in correctness issues +that are (in gVisor) currently handled by a number of coarse-grained locks that +serialize renames with all other filesystem operations. For shared filesystems, +this means inconsistent behavior in the presence of concurrent mutation. + +## Design + +Almost all Linux filesystem syscalls describe filesystem resources in one of two +ways: + +- Path-based: A filesystem position is described by a combination of a + starting position and a sequence of path components relative to that + position, where the starting position is one of: + + - The VFS root (defined by mount namespace and chroot), for absolute paths + + - The VFS position of an existing FD, for relative paths passed to `*at` + syscalls (e.g. `statat`) + + - The current working directory, for relative paths passed to non-`*at` + syscalls and `*at` syscalls with `AT_FDCWD` + +- File-description-based: A filesystem object is described by an existing FD, + passed to a `f*` syscall (e.g. `fstat`). + +Many of our issues with 9P arise from its (and VFS') interposition of a model +based on inodes between the filesystem syscall API and filesystem +implementations. We propose to replace 9P with a protocol that does not feature +inodes at all, and instead closely follows the filesystem syscall API by +featuring only path-based and FD-based operations, with minimal deviations as +necessary to ameliorate deficiencies in the syscall interface (see below). This +approach addresses the issues described above: + +- Even on shared filesystems, most application filesystem syscalls are + translated to a single RPC (possibly excepting special cases described + below), which is a logical lower bound. + +- The behavior of application syscalls on shared filesystems is + straightforwardly predictable: path-based syscalls are translated to + path-based RPCs, which will re-lookup the file at that path, and FD-based + syscalls are translated to FD-based RPCs, which use an existing open file + without performing another lookup. (This is at least true on gofers that + proxy the host local filesystem; other filesystems that lack support for + e.g. certain operations on FDs may have different behavior, but this + divergence is at least still predictable and inherent to the underlying + filesystem implementation.) + +Note that this approach is only feasible in gVisor's next-generation virtual +filesystem (VFS2), which does not assume the existence of inodes and allows the +remote filesystem client to translate whole path-based syscalls into RPCs. Thus +one of the unavoidable tradeoffs associated with such a protocol vs. 9P is the +inability to construct a Linux client that is performance-competitive with +gVisor. + +### File Permissions + +Many filesystem operations are side-effectual, such that file permissions must +be checked before such operations take effect. The simplest approach to file +permission checking is for the sentry to obtain permissions from the remote +filesystem, then apply permission checks in the sentry before performing the +application-requested operation. However, this requires an additional RPC per +application syscall (which can't be mitigated by caching on shared filesystems). +Alternatively, we may delegate file permission checking to gofers. In general, +file permission checks depend on the following properties of the accessor: + +- Filesystem UID/GID + +- Supplementary GIDs + +- Effective capabilities in the accessor's user namespace (i.e. the accessor's + effective capability set) + +- All UIDs and GIDs mapped in the accessor's user namespace (which determine + if the accessor's capabilities apply to accessed files) + +We may choose to delay implementation of file permission checking delegation, +although this is potentially costly since it doubles the number of required RPCs +for most operations on shared filesystems. We may also consider compromise +options, such as only delegating file permission checks for accessors in the +root user namespace. + +### Symbolic Links + +gVisor usually interprets symbolic link targets in its VFS rather than on the +filesystem containing the symbolic link; thus e.g. a symlink to +"/proc/self/maps" on a remote filesystem resolves to said file in the sentry's +procfs rather than the host's. This implies that: + +- Remote filesystem servers that proxy filesystems supporting symlinks must + check if each path component is a symlink during path traversal. + +- Absolute symlinks require that the sentry restart the operation at its + contextual VFS root (which is task-specific and may not be on a remote + filesystem at all), so if a remote filesystem server encounters an absolute + symlink during path traversal on behalf of a path-based operation, it must + terminate path traversal and return the symlink target. + +- Relative symlinks begin target resolution in the parent directory of the + symlink, so in theory most relative symlinks can be handled automatically + during the path traversal that encounters the symlink, provided that said + traversal is supplied with the number of remaining symlinks before `ELOOP`. + However, the new path traversed by the symlink target may cross VFS mount + boundaries, such that it's only safe for remote filesystem servers to + speculatively follow relative symlinks for side-effect-free operations such + as `stat` (where the sentry can simply ignore results that are inapplicable + due to crossing mount boundaries). We may choose to delay implementation of + this feature, at the cost of an additional RPC per relative symlink (note + that even if the symlink target crosses a mount boundary, the sentry will + need to `stat` the path to the mount boundary to confirm that each traversed + component is an accessible directory); until it is implemented, relative + symlinks may be handled like absolute symlinks, by terminating path + traversal and returning the symlink target. + +The possibility of symlinks (and the possibility of a compromised sentry) means +that the sentry may issue RPCs with paths that, in the absence of symlinks, +would traverse beyond the root of the remote filesystem. For example, the sentry +may issue an RPC with a path like "/foo/../..", on the premise that if "/foo" is +a symlink then the resulting path may be elsewhere on the remote filesystem. To +handle this, path traversal must also track its current depth below the remote +filesystem root, and terminate path traversal if it would ascend beyond this +point. + +### Path Traversal + +Since path-based VFS operations will translate to path-based RPCs, filesystem +servers will need to handle path traversal. From the perspective of a given +filesystem implementation in the server, there are two basic approaches to path +traversal: + +- Inode-walk: For each path component, obtain a handle to the underlying + filesystem object (e.g. with `open(O_PATH)`), check if that object is a + symlink (as described above) and that that object is accessible by the + caller (e.g. with `fstat()`), then continue to the next path component (e.g. + with `openat()`). This ensures that the checked filesystem object is the one + used to obtain the next object in the traversal, which is intuitively + appealing. However, while this approach works for host local filesystems, it + requires features that are not widely supported by other filesystems. + +- Path-walk: For each path component, use a path-based operation to determine + if the filesystem object currently referred to by that path component is a + symlink / is accessible. This is highly portable, but suffers from quadratic + behavior (at the level of the underlying filesystem implementation, the + first path component will be traversed a number of times equal to the number + of path components in the path). + +The implementation should support either option by delegating path traversal to +filesystem implementations within the server (like VFS and the remote filesystem +protocol itself), as inode-walking is still safe, efficient, amenable to FD +caching, and implementable on non-shared host local filesystems (a sufficiently +common case as to be worth considering in the design). + +Both approaches are susceptible to race conditions that may permit sandboxed +filesystem escapes: + +- Under inode-walk, a malicious application may cause a directory to be moved + (with `rename`) during path traversal, such that the filesystem + implementation incorrectly determines whether subsequent inodes are located + in paths that should be visible to sandboxed applications. + +- Under path-walk, a malicious application may cause a non-symlink file to be + replaced with a symlink during path traversal, such that following path + operations will incorrectly follow the symlink. + +Both race conditions can, to some extent, be mitigated in filesystem server +implementations by synchronizing path traversal with the hazardous operations in +question. However, shared filesystems are frequently used to share data between +sandboxed and unsandboxed applications in a controlled way, and in some cases a +malicious sandboxed application may be able to take advantage of a hazardous +filesystem operation performed by an unsandboxed application. In some cases, +filesystem features may be available to ensure safety even in such cases (e.g. +[the new openat2() syscall](https://man7.org/linux/man-pages/man2/openat2.2.html)), +but it is not clear how to solve this problem in general. (Note that this issue +is not specific to our design; rather, it is a fundamental limitation of +filesystem sandboxing.) + +### Filesystem Multiplexing + +A given sentry may need to access multiple distinct remote filesystems (e.g. +different volumes for a given container). In many cases, there is no advantage +to serving these filesystems from distinct filesystem servers, or accessing them +through distinct connections (factors such as maximum RPC concurrency should be +based on available host resources). Therefore, the protocol should support +multiplexing of distinct filesystem trees within a single session. 9P supports +this by allowing multiple calls to the `attach` RPC to produce fids representing +distinct filesystem trees, but this is somewhat clunky; we propose a much +simpler mechanism wherein each message that conveys a path also conveys a +numeric filesystem ID that identifies a filesystem tree. + +## Alternatives Considered + +### Additional Extensions to 9P + +There are at least three conceptual aspects to 9P: + +- Wire format: messages with a 4-byte little-endian size prefix, strings with + a 2-byte little-endian size prefix, etc. Whether the wire format is worth + retaining is unclear; in particular, it's unclear that the 9P wire format + has a significant advantage over protobufs, which are substantially easier + to extend. Note that the official Go protobuf implementation is widely known + to suffer from a significant number of performance deficiencies, so if we + choose to switch to protobuf, we may need to use an alternative toolchain + such as `gogo/protobuf` (which is also widely used in the Go ecosystem, e.g. + by Kubernetes). + +- Filesystem model: fids, qids, etc. Discarding this is one of the motivations + for this proposal. + +- RPCs: Twalk, Tlopen, etc. In addition to previously-described + inefficiencies, most of these are dependent on the filesystem model and + therefore must be discarded. + +### FUSE + +The FUSE (Filesystem in Userspace) protocol is frequently used to provide +arbitrary userspace filesystem implementations to a host Linux kernel. +Unfortunately, FUSE is also inode-based, and therefore doesn't address any of +the problems we have with 9P. + +### virtio-fs + +virtio-fs is an ongoing project aimed at improving Linux VM filesystem +performance when accessing Linux host filesystems (vs. virtio-9p). In brief, it +is based on: + +- Using a FUSE client in the guest that communicates over virtio with a FUSE + server in the host. + +- Using DAX to map the host page cache into the guest. + +- Using a file metadata table in shared memory to avoid VM exits for metadata + updates. + +None of these improvements seem applicable to gVisor: + +- As explained above, FUSE is still inode-based, so it is still susceptible to + most of the problems we have with 9P. + +- Our use of host file descriptors already allows us to leverage the host page + cache for file contents. + +- Our need for shared filesystem coherence is usually based on a user + requirement that an out-of-sandbox filesystem mutation is guaranteed to be + visible by all subsequent observations from within the sandbox, or vice + versa; it's not clear that this can be guaranteed without a synchronous + signaling mechanism like an RPC. diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD new file mode 100644 index 000000000..aac0161fa --- /dev/null +++ b/pkg/marshal/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "marshal", + srcs = [ + "marshal.go", + "marshal_impl_util.go", + ], + visibility = [ + "//:sandbox", + ], + deps = ["//pkg/usermem"], +) diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go new file mode 100644 index 000000000..d8cb44b40 --- /dev/null +++ b/pkg/marshal/marshal.go @@ -0,0 +1,184 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal defines the Marshallable interface for +// serialize/deserializing go data structures to/from memory, according to the +// Linux ABI. +// +// Implementations of this interface are typically automatically generated by +// tools/go_marshal. See the go_marshal README for details. +package marshal + +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// CopyContext defines the memory operations required to marshal to and from +// user memory. Typically, kernel.Task is used to provide implementations for +// these operations. +type CopyContext interface { + // CopyScratchBuffer provides a task goroutine-local scratch buffer. See + // kernel.CopyScratchBuffer. + CopyScratchBuffer(size int) []byte + + // CopyOutBytes writes the contents of b to the task's memory. See + // kernel.CopyOutBytes. + CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + + // CopyInBytes reads the contents of the task's memory to b. See + // kernel.CopyInBytes. + CopyInBytes(addr usermem.Addr, b []byte) (int, error) +} + +// Marshallable represents operations on a type that can be marshalled to and +// from memory. +// +// go-marshal automatically generates implementations for this interface for +// types marked as '+marshal'. +type Marshallable interface { + io.WriterTo + + // SizeBytes is the size of the memory representation of a type in + // marshalled form. + // + // SizeBytes must handle a nil receiver. Practically, this means SizeBytes + // cannot deference any fields on the object implementing it (but will + // likely make use of the type of these fields). + SizeBytes() int + + // MarshalBytes serializes a copy of a type to dst. + // Precondition: dst must be at least SizeBytes() in length. + MarshalBytes(dst []byte) + + // UnmarshalBytes deserializes a type from src. + // Precondition: src must be at least SizeBytes() in length. + UnmarshalBytes(src []byte) + + // Packed returns true if the marshalled size of the type is the same as the + // size it occupies in memory. This happens when the type has no fields + // starting at unaligned addresses (should always be true by default for ABI + // structs, verified by automatically generated tests when using + // go_marshal), and has no fields marked `marshal:"unaligned"`. + // + // Packed must return the same result for all possible values of the type + // implementing it. Violating this constraint implies the type doesn't have + // a static memory layout, and will lead to memory corruption. + // Go-marshal-generated code reuses the result of Packed for multiple values + // of the same type. + Packed() bool + + // MarshalUnsafe serializes a type by bulk copying its in-memory + // representation to the dst buffer. This is only safe to do when the type + // has no implicit padding, see Marshallable.Packed. When Packed would + // return false, MarshalUnsafe should fall back to the safer but slower + // MarshalBytes. + // Precondition: dst must be at least SizeBytes() in length. + MarshalUnsafe(dst []byte) + + // UnmarshalUnsafe deserializes a type by directly copying to the underlying + // memory allocated for the object by the runtime. + // + // This allows much faster unmarshalling of types which have no implicit + // padding, see Marshallable.Packed. When Packed would return false, + // UnmarshalUnsafe should fall back to the safer but slower unmarshal + // mechanism implemented in UnmarshalBytes. + // Precondition: src must be at least SizeBytes() in length. + UnmarshalUnsafe(src []byte) + + // CopyIn deserializes a Marshallable type from a task's memory. This may + // only be called from a task goroutine. This is more efficient than calling + // UnmarshalUnsafe on Marshallable.Packed types, as the type being + // marshalled does not escape. The implementation should avoid creating + // extra copies in memory by directly deserializing to the object's + // underlying memory. + // + // If the copy-in from the task memory is only partially successful, CopyIn + // should still attempt to deserialize as much data as possible. See comment + // for UnmarshalBytes. + CopyIn(cc CopyContext, addr usermem.Addr) (int, error) + + // CopyOut serializes a Marshallable type to a task's memory. This may only + // be called from a task goroutine. This is more efficient than calling + // MarshalUnsafe on Marshallable.Packed types, as the type being serialized + // does not escape. The implementation should avoid creating extra copies in + // memory by directly serializing from the object's underlying memory. + // + // The copy-out to the task memory may be partially successful, in which + // case CopyOut returns how much data was serialized. See comment for + // MarshalBytes for implications. + CopyOut(cc CopyContext, addr usermem.Addr) (int, error) + + // CopyOutN is like CopyOut, but explicitly requests a partial + // copy-out. Note that this may yield unexpected results for non-packed + // types and the caller may only want to allow this for packed types. See + // comment on MarshalBytes. + // + // The limit must be less than or equal to SizeBytes(). + CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) +} + +// go-marshal generates additional functions for a type based on additional +// clauses to the +marshal directive. They are documented below. +// +// Slice API +// ========= +// +// Adding a "slice" clause to the +marshal directive for structs or newtypes on +// primitives like this: +// +// // +marshal slice:FooSlice +// type Foo struct { ... } +// +// Generates four additional functions for marshalling slices of Foos like this: +// +// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It +// // might be more efficient that repeatedly calling Foo.MarshalUnsafe +// // over a []Foo in a loop if the type is Packed. +// // Preconditions: dst must be at least len(src)*Foo.SizeBytes() in length. +// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... } +// +// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It +// // might be more efficient that repeatedly calling Foo.UnmarshalUnsafe +// // over a []Foo in a loop if the type is Packed. +// // Preconditions: src must be at least len(dst)*Foo.SizeBytes() in length. +// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } +// +// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. +// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... } +// +// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory. +// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... } +// +// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where +// %s is the first argument to the slice clause. This directive is not supported +// for newtypes on arrays. +// +// The slice clause also takes an optional second argument, which must be the +// value "inner": +// +// // +marshal slice:Int32Slice:inner +// type Int32 int32 +// +// This is only valid on newtypes on primitives, and causes the generated +// functions to accept slices of the inner type instead: +// +// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... } +// +// Without "inner", they would instead be: +// +// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... } +// +// This may help avoid a cast depending on how the generated functions are used. diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go new file mode 100644 index 000000000..ea75e09f2 --- /dev/null +++ b/pkg/marshal/marshal_impl_util.go @@ -0,0 +1,78 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package marshal + +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// StubMarshallable implements the Marshallable interface. +// StubMarshallable is a convenient embeddable type for satisfying the +// marshallable interface, but provides no actual implementation. It is +// useful when the marshallable interface needs to be implemented manually, +// but the caller doesn't require the full marshallable interface. +type StubMarshallable struct{} + +// WriteTo implements Marshallable.WriteTo. +func (StubMarshallable) WriteTo(w io.Writer) (n int64, err error) { + panic("Please implement your own WriteTo function") +} + +// SizeBytes implements Marshallable.SizeBytes. +func (StubMarshallable) SizeBytes() int { + panic("Please implement your own SizeBytes function") +} + +// MarshalBytes implements Marshallable.MarshalBytes. +func (StubMarshallable) MarshalBytes(dst []byte) { + panic("Please implement your own MarshalBytes function") +} + +// UnmarshalBytes implements Marshallable.UnmarshalBytes. +func (StubMarshallable) UnmarshalBytes(src []byte) { + panic("Please implement your own UnmarshalBytes function") +} + +// Packed implements Marshallable.Packed. +func (StubMarshallable) Packed() bool { + panic("Please implement your own Packed function") +} + +// MarshalUnsafe implements Marshallable.MarshalUnsafe. +func (StubMarshallable) MarshalUnsafe(dst []byte) { + panic("Please implement your own MarshalUnsafe function") +} + +// UnmarshalUnsafe implements Marshallable.UnmarshalUnsafe. +func (StubMarshallable) UnmarshalUnsafe(src []byte) { + panic("Please implement your own UnmarshalUnsafe function") +} + +// CopyIn implements Marshallable.CopyIn. +func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) { + panic("Please implement your own CopyIn function") +} + +// CopyOut implements Marshallable.CopyOut. +func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) { + panic("Please implement your own CopyOut function") +} + +// CopyOutN implements Marshallable.CopyOutN. +func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) { + panic("Please implement your own CopyOutN function") +} diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD new file mode 100644 index 000000000..d77a11c79 --- /dev/null +++ b/pkg/marshal/primitive/BUILD @@ -0,0 +1,19 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "primitive", + srcs = [ + "primitive.go", + ], + marshal = True, + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/context", + "//pkg/marshal", + "//pkg/usermem", + ], +) diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go new file mode 100644 index 000000000..4b342de6b --- /dev/null +++ b/pkg/marshal/primitive/primitive.go @@ -0,0 +1,349 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package primitive defines marshal.Marshallable implementations for primitive +// types. +package primitive + +import ( + "io" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Int8 is a marshal.Marshallable implementation for int8. +// +// +marshal slice:Int8Slice:inner +type Int8 int8 + +// Uint8 is a marshal.Marshallable implementation for uint8. +// +// +marshal slice:Uint8Slice:inner +type Uint8 uint8 + +// Int16 is a marshal.Marshallable implementation for int16. +// +// +marshal slice:Int16Slice:inner +type Int16 int16 + +// Uint16 is a marshal.Marshallable implementation for uint16. +// +// +marshal slice:Uint16Slice:inner +type Uint16 uint16 + +// Int32 is a marshal.Marshallable implementation for int32. +// +// +marshal slice:Int32Slice:inner +type Int32 int32 + +// Uint32 is a marshal.Marshallable implementation for uint32. +// +// +marshal slice:Uint32Slice:inner +type Uint32 uint32 + +// Int64 is a marshal.Marshallable implementation for int64. +// +// +marshal slice:Int64Slice:inner +type Int64 int64 + +// Uint64 is a marshal.Marshallable implementation for uint64. +// +// +marshal slice:Uint64Slice:inner +type Uint64 uint64 + +// ByteSlice is a marshal.Marshallable implementation for []byte. +// This is a convenience wrapper around a dynamically sized type, and can't be +// embedded in other marshallable types because it breaks assumptions made by +// go-marshal internals. It violates the "no dynamically-sized types" +// constraint of the go-marshal library. +type ByteSlice []byte + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (b *ByteSlice) SizeBytes() int { + return len(*b) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (b *ByteSlice) MarshalBytes(dst []byte) { + copy(dst, *b) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (b *ByteSlice) UnmarshalBytes(src []byte) { + copy(*b, src) +} + +// Packed implements marshal.Marshallable.Packed. +func (b *ByteSlice) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *ByteSlice) MarshalUnsafe(dst []byte) { + b.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *ByteSlice) UnmarshalUnsafe(src []byte) { + b.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + return cc.CopyInBytes(addr, *b) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + return cc.CopyOutBytes(addr, *b) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + return cc.CopyOutBytes(addr, (*b)[:limit]) +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(*b) + return int64(n), err +} + +var _ marshal.Marshallable = (*ByteSlice)(nil) + +// Below, we define some convenience functions for marshalling primitive types +// using the newtypes above, without requiring superfluous casts. + +// 8-bit integers + +// CopyInt8In is a convenient wrapper for copying in an int8 from the task's +// memory. +func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) { + var buf Int8 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int8(buf) + return n, nil +} + +// CopyInt8Out is a convenient wrapper for copying out an int8 to the task's +// memory. +func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) { + srcP := Int8(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint8In is a convenient wrapper for copying in a uint8 from the task's +// memory. +func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) { + var buf Uint8 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint8(buf) + return n, nil +} + +// CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's +// memory. +func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) { + srcP := Uint8(src) + return srcP.CopyOut(cc, addr) +} + +// 16-bit integers + +// CopyInt16In is a convenient wrapper for copying in an int16 from the task's +// memory. +func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) { + var buf Int16 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int16(buf) + return n, nil +} + +// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's +// memory. +func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) { + srcP := Int16(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's +// memory. +func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) { + var buf Uint16 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint16(buf) + return n, nil +} + +// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's +// memory. +func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) { + srcP := Uint16(src) + return srcP.CopyOut(cc, addr) +} + +// 32-bit integers + +// CopyInt32In is a convenient wrapper for copying in an int32 from the task's +// memory. +func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) { + var buf Int32 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int32(buf) + return n, nil +} + +// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's +// memory. +func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) { + srcP := Int32(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's +// memory. +func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) { + var buf Uint32 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint32(buf) + return n, nil +} + +// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's +// memory. +func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) { + srcP := Uint32(src) + return srcP.CopyOut(cc, addr) +} + +// 64-bit integers + +// CopyInt64In is a convenient wrapper for copying in an int64 from the task's +// memory. +func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) { + var buf Int64 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int64(buf) + return n, nil +} + +// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's +// memory. +func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) { + srcP := Int64(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's +// memory. +func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) { + var buf Uint64 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint64(buf) + return n, nil +} + +// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's +// memory. +func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) { + srcP := Uint64(src) + return srcP.CopyOut(cc, addr) +} + +// CopyByteSliceIn is a convenient wrapper for copying in a []byte from the +// task's memory. +func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) { + var buf ByteSlice + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = []byte(buf) + return n, nil +} + +// CopyByteSliceOut is a convenient wrapper for copying out a []byte to the +// task's memory. +func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) { + srcP := ByteSlice(src) + return srcP.CopyOut(cc, addr) +} + +// CopyStringIn is a convenient wrapper for copying in a string from the +// task's memory. +func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) { + var buf ByteSlice + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = string(buf) + return n, nil +} + +// CopyStringOut is a convenient wrapper for copying out a string to the task's +// memory. +func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) { + srcP := ByteSlice(src) + return srcP.CopyOut(cc, addr) +} + +// IOCopyContext wraps an object implementing usermem.IO to implement +// marshal.CopyContext. +type IOCopyContext struct { + Ctx context.Context + IO usermem.IO + Opts usermem.IOOpts +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) +} diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 36832ec86..4b4f9bd52 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -225,9 +225,9 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree // Verify will modify the cursor for data, but always restores it to its // original position upon exit. The cursor for tree is modified and not // restored. -func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) error { +func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) { if readSize <= 0 { - return fmt.Errorf("Unexpected read size: %d", readSize) + return 0, fmt.Errorf("Unexpected read size: %d", readSize) } layout := InitLayout(int64(dataSize), dataAndTreeInSameFile) @@ -240,29 +240,30 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in // finishes. origOffset, err := data.Seek(0, io.SeekCurrent) if err != nil { - return fmt.Errorf("Find current data offset failed: %v", err) + return 0, fmt.Errorf("Find current data offset failed: %v", err) } defer data.Seek(origOffset, io.SeekStart) // Move to the first block that contains target data. if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil { - return fmt.Errorf("Seek to datablock start failed: %v", err) + return 0, fmt.Errorf("Seek to datablock start failed: %v", err) } buf := make([]byte, layout.blockSize) var readErr error - bytesRead := 0 + total := int64(0) for i := firstDataBlock; i <= lastDataBlock; i++ { // Read a block that includes all or part of target range in // input data. - bytesRead, readErr = data.Read(buf) + bytesRead, err := data.Read(buf) + readErr = err // If at the end of input data and all previous blocks are // verified, return the verified input data and EOF. if readErr == io.EOF && bytesRead == 0 { break } if readErr != nil && readErr != io.EOF { - return fmt.Errorf("Read from data failed: %v", err) + return 0, fmt.Errorf("Read from data failed: %v", err) } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. @@ -274,7 +275,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in } } if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil { - return err + return 0, err } // startOff is the beginning of the read range within the // current data block. Note that for all blocks other than the @@ -298,10 +299,14 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in if endOff > int64(bytesRead) { endOff = int64(bytesRead) } - w.Write(buf[startOff:endOff]) + n, err := w.Write(buf[startOff:endOff]) + if err != nil { + return total, err + } + total += int64(n) } - return readErr + return total, readErr } // verifyBlock verifies a block against tree. index is the number of block in diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index ad50ba5f6..daaca759a 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -67,17 +67,17 @@ func TestLayout(t *testing.T) { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) if l.blockSize != int64(usermem.PageSize) { - t.Errorf("got blockSize %d, want %d", l.blockSize, usermem.PageSize) + t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } if l.digestSize != sha256DigestSize { - t.Errorf("got digestSize %d, want %d", l.digestSize, sha256DigestSize) + t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { - t.Errorf("got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) + t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) } for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ { if l.levelOffset[i] != tc.expectedLevelOffset[i] { - t.Errorf("got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) + t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) } } }) @@ -169,11 +169,11 @@ func TestGenerate(t *testing.T) { }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) } if err != nil { - t.Fatalf("got err: %v, want nil", err) + t.Fatalf("Got err: %v, want nil", err) } if !bytes.Equal(root, tc.expectedRoot) { - t.Errorf("got root: %v, want %v", root, tc.expectedRoot) + t.Errorf("Got root: %v, want %v", root, tc.expectedRoot) } } }) @@ -334,14 +334,21 @@ func TestVerify(t *testing.T) { var buf bytes.Buffer data[tc.modifyByte] ^= 1 if tc.shouldSucceed { - if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err != nil && err != io.EOF { + n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { t.Errorf("Verification failed when expected to succeed: %v", err) } - if int64(buf.Len()) != tc.verifySize || !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output from Verify") + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") } } else { - if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { + if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { t.Errorf("Verification succeeded when expected to fail") } } @@ -382,14 +389,21 @@ func TestVerifyRandom(t *testing.T) { var buf bytes.Buffer // Checks that the random portion of data from the original data is // verified successfully. - if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err != nil && err != io.EOF { + n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { t.Errorf("Verification failed for correct data: %v", err) } if size > dataSize-start { size = dataSize - start } - if int64(buf.Len()) != size || !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output from Verify") + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") } buf.Reset() @@ -397,7 +411,7 @@ func TestVerifyRandom(t *testing.T) { randBytePos := rand.Int63n(size) data[start+randBytePos] ^= 1 - if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { + if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { t.Errorf("Verification succeeded for modified data") } } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 2ee07b664..28fe081d6 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -54,6 +54,8 @@ func (c *Client) newFile(fid FID) *clientFile { // // This proxies all of the interfaces found in file.go. type clientFile struct { + DisallowServerCalls + // client is the originating client. client *Client @@ -283,6 +285,39 @@ func (c *clientFile) Close() error { return nil } +// SetAttrClose implements File.SetAttrClose. +func (c *clientFile) SetAttrClose(valid SetAttrMask, attr SetAttr) error { + if !versionSupportsTsetattrclunk(c.client.version) { + setAttrErr := c.SetAttr(valid, attr) + + // Try to close file even in case of failure above. Since the state of the + // file is unknown to the caller, it will not attempt to close the file + // again. + if err := c.Close(); err != nil { + return err + } + + return setAttrErr + } + + // Avoid double close. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return syscall.EBADF + } + + // Send the message. + if err := c.client.sendRecv(&Tsetattrclunk{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattrclunk{}); err != nil { + // If an error occurred, we toss away the FID. This isn't ideal, + // but I'm not sure what else makes sense in this context. + log.Warningf("Tsetattrclunk failed, losing FID %v: %v", c.fid, err) + return err + } + + // Return the FID to the pool. + c.client.fidPool.Put(uint64(c.fid)) + return nil +} + // Open implements File.Open. func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) { if atomic.LoadUint32(&c.closed) != 0 { @@ -681,6 +716,3 @@ func (c *clientFile) Flush() error { return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{}) } - -// Renamed implements File.Renamed. -func (c *clientFile) Renamed(newDir File, newName string) {} diff --git a/pkg/p9/file.go b/pkg/p9/file.go index cab35896f..c2e3a3f98 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -135,6 +135,14 @@ type File interface { // On the server, Close has no concurrency guarantee. Close() error + // SetAttrClose is the equivalent of calling SetAttr() followed by Close(). + // This can be used to set file times before closing the file in a single + // operation. + // + // On the server, SetAttr has a write concurrency guarantee. + // On the server, Close has no concurrency guarantee. + SetAttrClose(valid SetAttrMask, attr SetAttr) error + // Open must be called prior to using Read, Write or Readdir. Once Open // is called, some operations, such as Walk, will no longer work. // @@ -286,3 +294,19 @@ type DefaultWalkGetAttr struct{} func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS } + +// DisallowClientCalls panics if a client-only function is called. +type DisallowClientCalls struct{} + +// SetAttrClose implements File.SetAttrClose. +func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { + panic("SetAttrClose should not be called on the server") +} + +// DisallowServerCalls panics if a server-only function is called. +type DisallowServerCalls struct{} + +// Renamed implements File.Renamed. +func (*clientFile) Renamed(File, string) { + panic("Renamed should not be called on the client") +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 1db5797dd..abd237f46 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -123,6 +123,37 @@ func (t *Tclunk) handle(cs *connState) message { return &Rclunk{} } +func (t *Tsetattrclunk) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + setAttrErr := ref.safelyWrite(func() error { + // We don't allow setattr on files that have been deleted. + // This might be technically incorrect, as it's possible that + // there were multiple links and you can still change the + // corresponding inode information. + if ref.isDeleted() { + return syscall.EINVAL + } + + // Set the attributes. + return ref.file.SetAttr(t.Valid, t.SetAttr) + }) + + // Try to delete FID even in case of failure above. Since the state of the + // file is unknown to the caller, it will not attempt to close the file again. + if !cs.DeleteFID(t.FID) { + return newErr(syscall.EBADF) + } + if setAttrErr != nil { + return newErr(setAttrErr) + } + return &Rsetattrclunk{} +} + // handle implements handler.handle. func (t *Tremove) handle(cs *connState) message { ref, ok := cs.LookupFID(t.FID) diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 2cb59f934..cf13cbb69 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -317,6 +317,64 @@ func (r *Rclunk) String() string { return "Rclunk{}" } +// Tsetattrclunk is a setattr+close request. +type Tsetattrclunk struct { + // FID is the FID to change. + FID FID + + // Valid is the set of bits which will be used. + Valid SetAttrMask + + // SetAttr is the set request. + SetAttr SetAttr +} + +// decode implements encoder.decode. +func (t *Tsetattrclunk) decode(b *buffer) { + t.FID = b.ReadFID() + t.Valid.decode(b) + t.SetAttr.decode(b) +} + +// encode implements encoder.encode. +func (t *Tsetattrclunk) encode(b *buffer) { + b.WriteFID(t.FID) + t.Valid.encode(b) + t.SetAttr.encode(b) +} + +// Type implements message.Type. +func (*Tsetattrclunk) Type() MsgType { + return MsgTsetattrclunk +} + +// String implements fmt.Stringer. +func (t *Tsetattrclunk) String() string { + return fmt.Sprintf("Tsetattrclunk{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr) +} + +// Rsetattrclunk is a setattr+close response. +type Rsetattrclunk struct { +} + +// decode implements encoder.decode. +func (*Rsetattrclunk) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rsetattrclunk) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rsetattrclunk) Type() MsgType { + return MsgRsetattrclunk +} + +// String implements fmt.Stringer. +func (r *Rsetattrclunk) String() string { + return "Rsetattrclunk{}" +} + // Tremove is a remove request. // // This will eventually be replaced by Tunlinkat. @@ -2657,6 +2715,8 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) + msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go index 7facc9f5e..bfeb6c236 100644 --- a/pkg/p9/messages_test.go +++ b/pkg/p9/messages_test.go @@ -376,6 +376,30 @@ func TestEncodeDecode(t *testing.T) { &Rumknod{ Rmknod{QID: QID{Type: 1}}, }, + &Tsetattrclunk{ + FID: 1, + Valid: SetAttrMask{ + Permissions: true, + UID: true, + GID: true, + Size: true, + ATime: true, + MTime: true, + CTime: true, + ATimeNotSystemTime: true, + MTimeNotSystemTime: true, + }, + SetAttr: SetAttr{ + Permissions: 1, + UID: 2, + GID: 3, + Size: 4, + ATimeSeconds: 5, + ATimeNanoSeconds: 6, + MTimeSeconds: 7, + MTimeNanoSeconds: 8, + }, + }, } for _, enc := range objs { diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 122c457d2..2235f8968 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -315,86 +315,88 @@ type MsgType uint8 // MsgType declarations. const ( - MsgTlerror MsgType = 6 - MsgRlerror = 7 - MsgTstatfs = 8 - MsgRstatfs = 9 - MsgTlopen = 12 - MsgRlopen = 13 - MsgTlcreate = 14 - MsgRlcreate = 15 - MsgTsymlink = 16 - MsgRsymlink = 17 - MsgTmknod = 18 - MsgRmknod = 19 - MsgTrename = 20 - MsgRrename = 21 - MsgTreadlink = 22 - MsgRreadlink = 23 - MsgTgetattr = 24 - MsgRgetattr = 25 - MsgTsetattr = 26 - MsgRsetattr = 27 - MsgTlistxattr = 28 - MsgRlistxattr = 29 - MsgTxattrwalk = 30 - MsgRxattrwalk = 31 - MsgTxattrcreate = 32 - MsgRxattrcreate = 33 - MsgTgetxattr = 34 - MsgRgetxattr = 35 - MsgTsetxattr = 36 - MsgRsetxattr = 37 - MsgTremovexattr = 38 - MsgRremovexattr = 39 - MsgTreaddir = 40 - MsgRreaddir = 41 - MsgTfsync = 50 - MsgRfsync = 51 - MsgTlink = 70 - MsgRlink = 71 - MsgTmkdir = 72 - MsgRmkdir = 73 - MsgTrenameat = 74 - MsgRrenameat = 75 - MsgTunlinkat = 76 - MsgRunlinkat = 77 - MsgTversion = 100 - MsgRversion = 101 - MsgTauth = 102 - MsgRauth = 103 - MsgTattach = 104 - MsgRattach = 105 - MsgTflush = 108 - MsgRflush = 109 - MsgTwalk = 110 - MsgRwalk = 111 - MsgTread = 116 - MsgRread = 117 - MsgTwrite = 118 - MsgRwrite = 119 - MsgTclunk = 120 - MsgRclunk = 121 - MsgTremove = 122 - MsgRremove = 123 - MsgTflushf = 124 - MsgRflushf = 125 - MsgTwalkgetattr = 126 - MsgRwalkgetattr = 127 - MsgTucreate = 128 - MsgRucreate = 129 - MsgTumkdir = 130 - MsgRumkdir = 131 - MsgTumknod = 132 - MsgRumknod = 133 - MsgTusymlink = 134 - MsgRusymlink = 135 - MsgTlconnect = 136 - MsgRlconnect = 137 - MsgTallocate = 138 - MsgRallocate = 139 - MsgTchannel = 250 - MsgRchannel = 251 + MsgTlerror MsgType = 6 + MsgRlerror MsgType = 7 + MsgTstatfs MsgType = 8 + MsgRstatfs MsgType = 9 + MsgTlopen MsgType = 12 + MsgRlopen MsgType = 13 + MsgTlcreate MsgType = 14 + MsgRlcreate MsgType = 15 + MsgTsymlink MsgType = 16 + MsgRsymlink MsgType = 17 + MsgTmknod MsgType = 18 + MsgRmknod MsgType = 19 + MsgTrename MsgType = 20 + MsgRrename MsgType = 21 + MsgTreadlink MsgType = 22 + MsgRreadlink MsgType = 23 + MsgTgetattr MsgType = 24 + MsgRgetattr MsgType = 25 + MsgTsetattr MsgType = 26 + MsgRsetattr MsgType = 27 + MsgTlistxattr MsgType = 28 + MsgRlistxattr MsgType = 29 + MsgTxattrwalk MsgType = 30 + MsgRxattrwalk MsgType = 31 + MsgTxattrcreate MsgType = 32 + MsgRxattrcreate MsgType = 33 + MsgTgetxattr MsgType = 34 + MsgRgetxattr MsgType = 35 + MsgTsetxattr MsgType = 36 + MsgRsetxattr MsgType = 37 + MsgTremovexattr MsgType = 38 + MsgRremovexattr MsgType = 39 + MsgTreaddir MsgType = 40 + MsgRreaddir MsgType = 41 + MsgTfsync MsgType = 50 + MsgRfsync MsgType = 51 + MsgTlink MsgType = 70 + MsgRlink MsgType = 71 + MsgTmkdir MsgType = 72 + MsgRmkdir MsgType = 73 + MsgTrenameat MsgType = 74 + MsgRrenameat MsgType = 75 + MsgTunlinkat MsgType = 76 + MsgRunlinkat MsgType = 77 + MsgTversion MsgType = 100 + MsgRversion MsgType = 101 + MsgTauth MsgType = 102 + MsgRauth MsgType = 103 + MsgTattach MsgType = 104 + MsgRattach MsgType = 105 + MsgTflush MsgType = 108 + MsgRflush MsgType = 109 + MsgTwalk MsgType = 110 + MsgRwalk MsgType = 111 + MsgTread MsgType = 116 + MsgRread MsgType = 117 + MsgTwrite MsgType = 118 + MsgRwrite MsgType = 119 + MsgTclunk MsgType = 120 + MsgRclunk MsgType = 121 + MsgTremove MsgType = 122 + MsgRremove MsgType = 123 + MsgTflushf MsgType = 124 + MsgRflushf MsgType = 125 + MsgTwalkgetattr MsgType = 126 + MsgRwalkgetattr MsgType = 127 + MsgTucreate MsgType = 128 + MsgRucreate MsgType = 129 + MsgTumkdir MsgType = 130 + MsgRumkdir MsgType = 131 + MsgTumknod MsgType = 132 + MsgRumknod MsgType = 133 + MsgTusymlink MsgType = 134 + MsgRusymlink MsgType = 135 + MsgTlconnect MsgType = 136 + MsgRlconnect MsgType = 137 + MsgTallocate MsgType = 138 + MsgRallocate MsgType = 139 + MsgTsetattrclunk MsgType = 140 + MsgRsetattrclunk MsgType = 141 + MsgTchannel MsgType = 250 + MsgRchannel MsgType = 251 ) // QIDType represents the file type for QIDs. diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e7bb3db2..6e605b14c 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -1225,22 +1225,31 @@ func TestOpen(t *testing.T) { func TestClose(t *testing.T) { type closeTest struct { name string - closeFn func(backend *Mock, f p9.File) + closeFn func(backend *Mock, f p9.File) error } cases := []closeTest{ { name: "close", - closeFn: func(_ *Mock, f p9.File) { - f.Close() + closeFn: func(_ *Mock, f p9.File) error { + return f.Close() }, }, { name: "remove", - closeFn: func(backend *Mock, f p9.File) { + closeFn: func(backend *Mock, f p9.File) error { // Allow the rename call in the parent, automatically translated. backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) - f.(deprecatedRemover).Remove() + return f.(deprecatedRemover).Remove() + }, + }, + { + name: "setAttrClose", + closeFn: func(backend *Mock, f p9.File) error { + valid := p9.SetAttrMask{ATime: true} + attr := p9.SetAttr{ATimeSeconds: 1, ATimeNanoSeconds: 2} + backend.EXPECT().SetAttr(valid, attr).Times(1) + return f.SetAttrClose(valid, attr) }, }, } @@ -1258,7 +1267,9 @@ func TestClose(t *testing.T) { _, backend, f := walkHelper(h, name, root) // Close via the prescribed method. - tc.closeFn(backend, f) + if err := tc.closeFn(backend, f); err != nil { + t.Fatalf("closeFn failed: %v", err) + } // Everything should fail with EBADF. if _, _, err := f.Walk(nil); err != syscall.EBADF { diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 09cde9f5a..8d7168ef5 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 11 + highestSupportedVersion uint32 = 12 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -173,3 +173,9 @@ func versionSupportsGetSetXattr(v uint32) bool { func versionSupportsListRemoveXattr(v uint32) bool { return v >= 11 } + +// versionSupportsTsetattrclunk returns true if version v supports +// the Tsetattrclunk message. +func versionSupportsTsetattrclunk(v uint32) bool { + return v >= 12 +} diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 57d8542b9..699ea8ac3 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -257,6 +257,8 @@ func (l *LeakMode) Get() interface{} { // String implements flag.Value. func (l *LeakMode) String() string { switch *l { + case UninitializedLeakChecking: + return "uninitialized" case NoLeakChecking: return "disabled" case LeaksLogWarning: @@ -264,7 +266,7 @@ func (l *LeakMode) String() string { case LeaksLogTraces: return "log-traces" } - panic(fmt.Sprintf("invalid ref leak mode %q", *l)) + panic(fmt.Sprintf("invalid ref leak mode %d", *l)) } // leakMode stores the current mode for the reference leak checker. diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD index ce30382ab..68ed074f8 100644 --- a/pkg/safemem/BUILD +++ b/pkg/safemem/BUILD @@ -11,9 +11,7 @@ go_library( "seq_unsafe.go", ], visibility = ["//:sandbox"], - deps = [ - "//pkg/safecopy", - ], + deps = ["//pkg/safecopy"], ) go_test( diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 55fd6967e..752e2dc32 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package seccomp provides basic seccomp filters for x86_64 (little endian). +// Package seccomp provides generation of basic seccomp filters. Currently, +// only little endian systems are supported. package seccomp import ( @@ -64,9 +65,9 @@ func Install(rules SyscallRules) error { Rules: rules, Action: linux.SECCOMP_RET_ALLOW, }, - }, defaultAction) + }, defaultAction, defaultAction) if log.IsLogging(log.Debug) { - programStr, errDecode := bpf.DecodeProgram(instrs) + programStr, errDecode := bpf.DecodeInstructions(instrs) if errDecode != nil { programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr) } @@ -117,7 +118,7 @@ var SyscallName = func(sysno uintptr) string { // BuildProgram builds a BPF program from the given map of actions to matching // SyscallRules. The single generated program covers all provided RuleSets. -func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) { +func BuildProgram(rules []RuleSet, defaultAction, badArchAction linux.BPFAction) ([]linux.BPFInstruction, error) { program := bpf.NewProgramBuilder() // Be paranoid and check that syscall is done in the expected architecture. @@ -128,7 +129,7 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn // defaultLabel is at the bottom of the program. The size of program // may exceeds 255 lines, which is the limit of a condition jump. program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0) - program.AddDirectJumpLabel(defaultLabel) + program.AddStmt(bpf.Ret|bpf.K, uint32(badArchAction)) if err := buildIndex(rules, program); err != nil { return nil, err } @@ -144,6 +145,11 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn // buildIndex builds a BST to quickly search through all syscalls. func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error { + // Do nothing if rules is empty. + if len(rules) == 0 { + return nil + } + // Build a list of all application system calls, across all given rule // sets. We have a simple BST, but may dispatch individual matchers // with different actions. The matchers are evaluated linearly. @@ -216,42 +222,163 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc labelled := false for i, arg := range rule { if arg != nil { + // Break out early if using MatchAny since no further + // instructions are required. + if _, ok := arg.(MatchAny); ok { + continue + } + + // Determine the data offset for low and high bits of input. + dataOffsetLow := seccompDataOffsetArgLow(i) + dataOffsetHigh := seccompDataOffsetArgHigh(i) + if i == RuleIP { + dataOffsetLow = seccompDataOffsetIPLow + dataOffsetHigh = seccompDataOffsetIPHigh + } + + // Add the conditional operation. Input values to the BPF + // program are 64bit values. However, comparisons in BPF can + // only be done on 32bit values. This means that we need to do + // multiple BPF comparisons in order to do one logical 64bit + // comparison. switch a := arg.(type) { - case AllowAny: - case AllowValue: - dataOffsetLow := seccompDataOffsetArgLow(i) - dataOffsetHigh := seccompDataOffsetArgHigh(i) - if i == RuleIP { - dataOffsetLow = seccompDataOffsetIPLow - dataOffsetHigh = seccompDataOffsetIPHigh - } + case EqualTo: + // EqualTo checks that both the higher and lower 32bits are equal. high, low := uint32(a>>32), uint32(a) - // assert arg_low == low + + // Assert that the lower 32bits are equal. + // arg_low == low ? continue : violation p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) - // assert arg_high == high + + // Assert that the lower 32bits are also equal. + // arg_high == high ? continue/success : violation p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) labelled = true + case NotEqual: + // NotEqual checks that either the higher or lower 32bits + // are *not* equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("ne%v", i) + + // Check if the higher 32bits are (not) equal. + // arg_low == low ? continue : success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are not equal (assuming + // higher bits are equal). + // arg_high == high ? violation : continue/success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true case GreaterThan: - dataOffsetLow := seccompDataOffsetArgLow(i) - dataOffsetHigh := seccompDataOffsetArgHigh(i) - if i == RuleIP { - dataOffsetLow = seccompDataOffsetIPLow - dataOffsetHigh = seccompDataOffsetIPHigh - } - labelGood := fmt.Sprintf("gt%v", i) + // GreaterThan checks that the higher 32bits is greater + // *or* that the higher 32bits are equal and the lower + // 32bits are greater. high, low := uint32(a>>32), uint32(a) - // assert arg_high < high + labelGood := fmt.Sprintf("gt%v", i) + + // Assert the higher 32bits are greater than or equal. + // arg_high >= high ? continue : violation (arg_high < high) p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) - // arg_high > high + + // Assert that the lower 32bits are greater. + // arg_high == high ? continue : success (arg_high > high) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) - // arg_low < low + // arg_low > low ? continue/success : violation (arg_high == high and arg_low <= low) p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) labelled = true + case GreaterThanOrEqual: + // GreaterThanOrEqual checks that the higher 32bits is + // greater *or* that the higher 32bits are equal and the + // lower 32bits are greater than or equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("ge%v", i) + + // Assert the higher 32bits are greater than or equal. + // arg_high >= high ? continue : violation (arg_high < high) + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + // arg_high == high ? continue : success (arg_high > high) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are greater (assuming the + // higher bits are equal). + // arg_low >= low ? continue/success : violation (arg_high == high and arg_low < low) + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case LessThan: + // LessThan checks that the higher 32bits is less *or* that + // the higher 32bits are equal and the lower 32bits are + // less. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("lt%v", i) + + // Assert the higher 32bits are less than or equal. + // arg_high > high ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + // arg_high == high ? continue : success (arg_high < high) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are less (assuming the + // higher bits are equal). + // arg_low >= low ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jge|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case LessThanOrEqual: + // LessThan checks that the higher 32bits is less *or* that + // the higher 32bits are equal and the lower 32bits are + // less than or equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("le%v", i) + + // Assert the higher 32bits are less than or equal. + // assert arg_high > high ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + // arg_high == high ? continue : success + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert the lower bits are less than or equal (assuming + // the higher bits are equal). + // arg_low > low ? violation : success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case maskedEqual: + // MaskedEqual checks that the bitwise AND of the value and + // mask are equal for both the higher and lower 32bits. + high, low := uint32(a.value>>32), uint32(a.value) + maskHigh, maskLow := uint32(a.mask>>32), uint32(a.mask) + + // Assert that the lower 32bits are equal when masked. + // A <- arg_low. + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + // A <- arg_low & maskLow + p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskLow) + // Assert that arg_low & maskLow == low. + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + + // Assert that the higher 32bits are equal when masked. + // A <- arg_high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + // A <- arg_high & maskHigh + p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskHigh) + // Assert that arg_high & maskHigh == high. + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + labelled = true default: return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) } diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index a52dc1b4e..daf165bbf 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -39,28 +39,79 @@ func seccompDataOffsetArgHigh(i int) uint32 { return seccompDataOffsetArgLow(i) + 4 } -// AllowAny is marker to indicate any value will be accepted. -type AllowAny struct{} +// MatchAny is marker to indicate any value will be accepted. +type MatchAny struct{} -func (a AllowAny) String() (s string) { +func (a MatchAny) String() (s string) { return "*" } -// AllowValue specifies a value that needs to be strictly matched. -type AllowValue uintptr +// EqualTo specifies a value that needs to be strictly matched. +type EqualTo uintptr + +func (a EqualTo) String() (s string) { + return fmt.Sprintf("== %#x", uintptr(a)) +} + +// NotEqual specifies a value that is strictly not equal. +type NotEqual uintptr + +func (a NotEqual) String() (s string) { + return fmt.Sprintf("!= %#x", uintptr(a)) +} // GreaterThan specifies a value that needs to be strictly smaller. type GreaterThan uintptr -func (a AllowValue) String() (s string) { - return fmt.Sprintf("%#x ", uintptr(a)) +func (a GreaterThan) String() (s string) { + return fmt.Sprintf("> %#x", uintptr(a)) +} + +// GreaterThanOrEqual specifies a value that needs to be smaller or equal. +type GreaterThanOrEqual uintptr + +func (a GreaterThanOrEqual) String() (s string) { + return fmt.Sprintf(">= %#x", uintptr(a)) +} + +// LessThan specifies a value that needs to be strictly greater. +type LessThan uintptr + +func (a LessThan) String() (s string) { + return fmt.Sprintf("< %#x", uintptr(a)) +} + +// LessThanOrEqual specifies a value that needs to be greater or equal. +type LessThanOrEqual uintptr + +func (a LessThanOrEqual) String() (s string) { + return fmt.Sprintf("<= %#x", uintptr(a)) +} + +type maskedEqual struct { + mask uintptr + value uintptr +} + +func (a maskedEqual) String() (s string) { + return fmt.Sprintf("& %#x == %#x", a.mask, a.value) +} + +// MaskedEqual specifies a value that matches the input after the input is +// masked (bitwise &) against the given mask. Can be used to verify that input +// only includes certain approved flags. +func MaskedEqual(mask, value uintptr) interface{} { + return maskedEqual{ + mask: mask, + value: value, + } } // Rule stores the allowed syscall arguments. // // For example: // rule := Rule { -// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0 +// EqualTo(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0 // } type Rule [7]interface{} // 6 arguments + RIP @@ -89,12 +140,12 @@ func (r Rule) String() (s string) { // rules := SyscallRules{ // syscall.SYS_FUTEX: []Rule{ // { -// AllowAny{}, -// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), +// MatchAny{}, +// EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), // }, // OR // { -// AllowAny{}, -// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), +// MatchAny{}, +// EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), // }, // }, // syscall.SYS_GETPID: []Rule{}, diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 5238df8bd..23f30678d 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -76,11 +76,14 @@ func TestBasic(t *testing.T) { } for _, test := range []struct { + name string ruleSets []RuleSet defaultAction linux.BPFAction + badArchAction linux.BPFAction specs []spec }{ { + name: "Single syscall", ruleSets: []RuleSet{ { Rules: SyscallRules{1: {}}, @@ -88,26 +91,28 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Single syscall allowed", + desc: "syscall allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Single syscall disallowed", + desc: "syscall disallowed", data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Multiple rulesets", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0x1), + EqualTo(0x1), }, }, }, @@ -122,30 +127,32 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_KILL_THREAD, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Multiple rulesets allowed (1a)", + desc: "allowed (1a)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple rulesets allowed (1b)", + desc: "allowed (1b)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple rulesets allowed (2)", + desc: "syscall 1 matched 2nd rule", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple rulesets allowed (2)", + desc: "no match", data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, }, { + name: "Multiple syscalls", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -157,50 +164,52 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Multiple syscalls allowed (1)", + desc: "allowed (1)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls allowed (3)", + desc: "allowed (3)", data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls allowed (5)", + desc: "allowed (5)", data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls disallowed (0)", + desc: "disallowed (0)", data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (2)", + desc: "disallowed (2)", data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (4)", + desc: "disallowed (4)", data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (6)", + desc: "disallowed (6)", data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (100)", + desc: "disallowed (100)", data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Wrong architecture", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -210,15 +219,17 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Wrong architecture", + desc: "arch (123)", data: seccompData{nr: 1, arch: 123}, - want: linux.SECCOMP_RET_TRAP, + want: linux.SECCOMP_RET_KILL_THREAD, }, }, }, { + name: "Syscall disallowed", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -228,22 +239,24 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall disallowed, action trap", + desc: "action trap", data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Syscall arguments", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowAny{}, - AllowValue(0xf), + MatchAny{}, + EqualTo(0xf), }, }, }, @@ -251,29 +264,31 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall argument allowed", + desc: "allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Syscall argument disallowed", + desc: "disallowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Multiple arguments", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0xf), + EqualTo(0xf), }, { - AllowValue(0xe), + EqualTo(0xe), }, }, }, @@ -281,28 +296,30 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall argument allowed, two rules", + desc: "match first rule", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Syscall argument allowed, two rules", + desc: "match 2nd rule", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}}, want: linux.SECCOMP_RET_ALLOW, }, }, }, { + name: "EqualTo", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0), - AllowValue(math.MaxUint64 - 1), - AllowValue(math.MaxUint32), + EqualTo(0), + EqualTo(math.MaxUint64 - 1), + EqualTo(math.MaxUint32), }, }, }, @@ -310,9 +327,10 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "64bit syscall argument allowed", + desc: "argument allowed (all match)", data: seccompData{ nr: 1, arch: LINUX_AUDIT_ARCH, @@ -321,7 +339,7 @@ func TestBasic(t *testing.T) { want: linux.SECCOMP_RET_ALLOW, }, { - desc: "64bit syscall argument disallowed", + desc: "argument disallowed (one mismatch)", data: seccompData{ nr: 1, arch: LINUX_AUDIT_ARCH, @@ -330,7 +348,7 @@ func TestBasic(t *testing.T) { want: linux.SECCOMP_RET_TRAP, }, { - desc: "64bit syscall argument disallowed", + desc: "argument disallowed (multiple mismatch)", data: seccompData{ nr: 1, arch: LINUX_AUDIT_ARCH, @@ -341,6 +359,103 @@ func TestBasic(t *testing.T) { }, }, { + name: "NotEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + NotEqual(0x7aabbccdd), + NotEqual(math.MaxUint64 - 1), + NotEqual(math.MaxUint32), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (one equal)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (all equal)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThan", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + GreaterThan(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThan (multi)", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -355,46 +470,145 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "GreaterThan: Syscall argument allowed", + desc: "arg allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "GreaterThan: Syscall argument disallowed (equal)", + desc: "arg disallowed (first arg equal)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Syscall argument disallowed (smaller)", + desc: "arg disallowed (first arg smaller)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "GreaterThan2: Syscall argument allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xfbcd000d}}, + desc: "arg disallowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg smaller)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThanOrEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + GreaterThanOrEqual(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "GreaterThan2: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThanOrEqual (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + GreaterThanOrEqual(0xf), + GreaterThanOrEqual(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed (both greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (first arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg smaller)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "GreaterThan2: Syscall argument disallowed (smaller)", + desc: "arg allowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (second arg smaller)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, + { + desc: "arg disallowed (both arg smaller)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, }, }, { + name: "LessThan", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - RuleIP: AllowValue(0x7aabbccdd), + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + LessThan(0x00000002_00000002), }, }, }, @@ -402,40 +616,307 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "IP: Syscall instruction pointer allowed", + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + }, + }, + { + name: "LessThan (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + LessThan(0x1), + LessThan(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (first arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "LessThanOrEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + LessThanOrEqual(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + }, + }, + + { + name: "LessThanOrEqual (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + LessThanOrEqual(0x1), + LessThanOrEqual(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (first arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg allowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (second arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "MaskedEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // x & 00000001 00000011 (0x103) == 00000000 00000001 (0x1) + // Input x must have lowest order bit set and + // must *not* have 8th or second lowest order bit set. + MaskedEqual(0x103, 0x1), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed (low order mandatory bit)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000001 + args: [6]uint64{0x1}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (low order optional bit)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000101 + args: [6]uint64{0x5}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (lowest order bit not set)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000010 + args: [6]uint64{0x2}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second lowest order bit set)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000011 + args: [6]uint64{0x3}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (8th bit set)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000001 00000000 + args: [6]uint64{0x100}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "Instruction Pointer", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + RuleIP: EqualTo(0x7aabbccdd), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "IP: Syscall instruction pointer disallowed", + desc: "disallowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344}, want: linux.SECCOMP_RET_TRAP, }, }, }, } { - instrs, err := BuildProgram(test.ruleSets, test.defaultAction) - if err != nil { - t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err) - continue - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err) - continue - } - for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) + t.Run(test.name, func(t *testing.T) { + instrs, err := BuildProgram(test.ruleSets, test.defaultAction, test.badArchAction) if err != nil { - t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err) - continue + t.Fatalf("BuildProgram() got error: %v", err) + } + p, err := bpf.Compile(instrs) + if err != nil { + t.Fatalf("bpf.Compile() got error: %v", err) } - if got != uint32(spec.want) { - t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want) + for _, spec := range test.specs { + got, err := bpf.Exec(p, spec.data.asInput()) + if err != nil { + t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err) + } + if got != uint32(spec.want) { + // Include a decoded version of the program in output for debugging purposes. + decoded, _ := bpf.DecodeInstructions(instrs) + t.Fatalf("%s: got: %d, want: %d\nBPF Program\n%s", spec.desc, got, spec.want, decoded) + } } - } + }) } } @@ -457,7 +938,7 @@ func TestRandom(t *testing.T) { Rules: syscallRules, Action: linux.SECCOMP_RET_ALLOW, }, - }, linux.SECCOMP_RET_TRAP) + }, linux.SECCOMP_RET_TRAP, linux.SECCOMP_RET_KILL_THREAD) if err != nil { t.Fatalf("buildProgram() got error: %v", err) } diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index fe157f539..7f33e0d9e 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -100,7 +100,7 @@ func main() { if !die { syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{ { - seccomp.AllowValue(10), + seccomp.EqualTo(10), }, } } diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 901e0f320..4af4d6e84 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -22,6 +22,7 @@ go_library( "signal_info.go", "signal_stack.go", "stack.go", + "stack_unsafe.go", "syscalls_amd64.go", "syscalls_arm64.go", ], @@ -33,11 +34,12 @@ go_library( "//pkg/context", "//pkg/cpuid", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/limits", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index a903d031c..d75d665ae 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -72,12 +73,12 @@ type Context interface { // with return values of varying sizes (for example ARCH_GETFS). This // is a simple utility function to convert to the native size in these // cases, and then we can CopyOut. - Native(val uintptr) interface{} + Native(val uintptr) marshal.Marshallable // Value converts a native type back to a generic value. // Once a value has been converted to native via the above call -- it // can be converted back here. - Value(val interface{}) uintptr + Value(val marshal.Marshallable) uintptr // Width returns the number of bytes for a native value. Width() uint @@ -205,7 +206,7 @@ type Context interface { // equivalent of arch_ptrace(): // PtracePeekUser implements ptrace(PTRACE_PEEKUSR). - PtracePeekUser(addr uintptr) (interface{}, error) + PtracePeekUser(addr uintptr) (marshal.Marshallable, error) // PtracePokeUser implements ptrace(PTRACE_POKEUSR). PtracePokeUser(addr, data uintptr) error diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 1c3e3c14c..c7d3a206d 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -23,6 +23,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -179,14 +181,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { } // Native returns the native type for the given val. -func (c *context64) Native(val uintptr) interface{} { - v := uint64(val) +func (c *context64) Native(val uintptr) marshal.Marshallable { + v := primitive.Uint64(val) return &v } // Value returns the generic val for the given native type. -func (c *context64) Value(val interface{}) uintptr { - return uintptr(*val.(*uint64)) +func (c *context64) Value(val marshal.Marshallable) uintptr { + return uintptr(*val.(*primitive.Uint64)) } // Width returns the byte width of this architecture. @@ -293,7 +295,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { const userStructSize = 928 // PtracePeekUser implements Context.PtracePeekUser. -func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { +func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { if addr&7 != 0 || addr >= userStructSize { return nil, syscall.EIO } diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 550741d8c..680d23a9f 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -22,6 +22,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -163,14 +165,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { } // Native returns the native type for the given val. -func (c *context64) Native(val uintptr) interface{} { - v := uint64(val) +func (c *context64) Native(val uintptr) marshal.Marshallable { + v := primitive.Uint64(val) return &v } // Value returns the generic val for the given native type. -func (c *context64) Value(val interface{}) uintptr { - return uintptr(*val.(*uint64)) +func (c *context64) Value(val marshal.Marshallable) uintptr { + return uintptr(*val.(*primitive.Uint64)) } // Width returns the byte width of this architecture. @@ -274,7 +276,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { } // PtracePeekUser implements Context.PtracePeekUser. -func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { +func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64. return c.Native(0), nil } diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go index 32173aa20..d3e2324a8 100644 --- a/pkg/sentry/arch/signal_act.go +++ b/pkg/sentry/arch/signal_act.go @@ -14,7 +14,7 @@ package arch -import "gvisor.dev/gvisor/tools/go_marshal/marshal" +import "gvisor.dev/gvisor/pkg/marshal" // Special values for SignalAct.Handler. const ( diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 6fb756f0e..72e07a988 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -17,17 +17,19 @@ package arch import ( - "encoding/binary" "math" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the // second argument to signal handlers set by signal(2). +// +// +marshal type SignalContext64 struct { R8 uint64 R9 uint64 @@ -68,6 +70,8 @@ const ( ) // UContext64 is equivalent to ucontext_t on 64-bit x86. +// +// +marshal type UContext64 struct { Flags uint64 Link uint64 @@ -172,12 +176,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // "... the value (%rsp+8) is always a multiple of 16 (...) when // control is transferred to the function entry point." - AMD64 ABI - ucSize := binary.Size(uc) - if ucSize < 0 { - // This can only happen if we've screwed up the definition of - // UContext64. - panic("can't get size of UContext64") - } + ucSize := uc.SizeBytes() // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128. frameSize := int(st.Arch.Width()) + ucSize + 128 frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8 @@ -195,18 +194,18 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt info.FixSignalCodeForUser() // Set up the stack frame. - infoAddr, err := st.Push(info) - if err != nil { + if _, err := info.CopyOut(st, StackBottomMagic); err != nil { return err } - ucAddr, err := st.Push(uc) - if err != nil { + infoAddr := st.Bottom + if _, err := uc.CopyOut(st, StackBottomMagic); err != nil { return err } + ucAddr := st.Bottom if act.HasRestorer() { // Push the restorer return address. // Note that this doesn't need to be popped. - if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil { + if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil { return err } } else { @@ -240,11 +239,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { // Copy out the stack frame. var uc UContext64 - if _, err := st.Pop(&uc); err != nil { + if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } var info SignalInfo - if _, err := st.Pop(&info); err != nil { + if _, err := info.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 642c79dda..7fde5d34e 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package arch import ( - "encoding/binary" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +26,8 @@ import ( // SignalContext64 is equivalent to struct sigcontext, the type passed as the // second argument to signal handlers set by signal(2). +// +// +marshal type SignalContext64 struct { FaultAddr uint64 Regs [31]uint64 @@ -36,6 +39,7 @@ type SignalContext64 struct { Reserved [3568]uint8 } +// +marshal type aarch64Ctx struct { Magic uint32 Size uint32 @@ -43,6 +47,8 @@ type aarch64Ctx struct { // FpsimdContext is equivalent to struct fpsimd_context on arm64 // (arch/arm64/include/uapi/asm/sigcontext.h). +// +// +marshal type FpsimdContext struct { Head aarch64Ctx Fpsr uint32 @@ -51,13 +57,15 @@ type FpsimdContext struct { } // UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h). +// +// +marshal type UContext64 struct { Flags uint64 Link uint64 Stack SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t - _pad [(1024 - 64) / 8]byte + _pad [120]byte // (1024 - 64) / 8 = 120 // sigcontext must be aligned to 16-byte _pad2 [8]byte // last for future expansion @@ -94,11 +102,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt }, Sigset: sigset, } - - ucSize := binary.Size(uc) - if ucSize < 0 { - panic("can't get size of UContext64") - } + ucSize := uc.SizeBytes() // frameSize = ucSize + sizeof(siginfo). // sizeof(siginfo) == 128. @@ -119,14 +123,14 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt info.FixSignalCodeForUser() // Set up the stack frame. - infoAddr, err := st.Push(info) - if err != nil { + if _, err := info.CopyOut(st, StackBottomMagic); err != nil { return err } - ucAddr, err := st.Push(uc) - if err != nil { + infoAddr := st.Bottom + if _, err := uc.CopyOut(st, StackBottomMagic); err != nil { return err } + ucAddr := st.Bottom // Set up registers. c.Regs.Sp = uint64(st.Bottom) @@ -147,11 +151,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { // Copy out the stack frame. var uc UContext64 - if _, err := st.Pop(&uc); err != nil { + if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } var info SignalInfo - if _, err := st.Pop(&info); err != nil { + if _, err := info.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go index 0fa738a1d..a1eae98f9 100644 --- a/pkg/sentry/arch/signal_stack.go +++ b/pkg/sentry/arch/signal_stack.go @@ -17,8 +17,8 @@ package arch import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) const ( diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 1108fa0bd..5f06c751d 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -15,14 +15,16 @@ package arch import ( - "encoding/binary" - "fmt" - "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" ) -// Stack is a simple wrapper around a usermem.IO and an address. +// Stack is a simple wrapper around a usermem.IO and an address. Stack +// implements marshal.CopyContext, and marshallable values can be pushed or +// popped from the stack through the marshal.Marshallable interface. +// +// Stack is not thread-safe. type Stack struct { // Our arch info. // We use this for automatic Native conversion of usermem.Addrs during @@ -34,105 +36,60 @@ type Stack struct { // Our current stack bottom. Bottom usermem.Addr -} -// Push pushes the given values on to the stack. -// -// (This method supports Addrs and treats them as native types.) -func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) { - for _, v := range vals { - - // We convert some types to well-known serializable quanities. - var norm interface{} - - // For array types, we will automatically add an appropriate - // terminal value. This is done simply to make the interface - // easier to use. - var term interface{} - - switch v.(type) { - case string: - norm = []byte(v.(string)) - term = byte(0) - case []int8, []uint8: - norm = v - term = byte(0) - case []int16, []uint16: - norm = v - term = uint16(0) - case []int32, []uint32: - norm = v - term = uint32(0) - case []int64, []uint64: - norm = v - term = uint64(0) - case []usermem.Addr: - // Special case: simply push recursively. - _, err := s.Push(s.Arch.Native(uintptr(0))) - if err != nil { - return 0, err - } - varr := v.([]usermem.Addr) - for i := len(varr) - 1; i >= 0; i-- { - _, err := s.Push(varr[i]) - if err != nil { - return 0, err - } - } - continue - case usermem.Addr: - norm = s.Arch.Native(uintptr(v.(usermem.Addr))) - default: - norm = v - } + // Scratch buffer used for marshalling to avoid having to repeatedly + // allocate scratch memory. + scratchBuf []byte +} - if term != nil { - _, err := s.Push(term) - if err != nil { - return 0, err - } - } +// scratchBufLen is the default length of Stack.scratchBuf. The +// largest structs the stack regularly serializes are arch.SignalInfo +// and arch.UContext64. We'll set the default size as the larger of +// the two, arch.UContext64. +var scratchBufLen = (*UContext64)(nil).SizeBytes() - c := binary.Size(norm) - if c < 0 { - return 0, fmt.Errorf("bad binary.Size for %T", v) - } - n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{}) - if err != nil || c != n { - return 0, err - } +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (s *Stack) CopyScratchBuffer(size int) []byte { + if len(s.scratchBuf) < size { + s.scratchBuf = make([]byte, size) + } + return s.scratchBuf[:size] +} +// StackBottomMagic is the special address callers must past to all stack +// marshalling operations to cause the src/dst address to be computed based on +// the current end of the stack. +const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1) + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes +// computes an appropriate address based on the current end of the +// stack. Callers use the sentinel address StackBottomMagic to marshal methods +// to indicate this. +func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) { + if sentinel != StackBottomMagic { + panic("Attempted to copy out to stack with absolute address") + } + c := len(b) + n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{}) + if err == nil && n == c { s.Bottom -= usermem.Addr(n) } - - return s.Bottom, nil + return n, err } -// Pop pops the given values off the stack. -// -// (This method supports Addrs and treats them as native types.) -func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) { - for _, v := range vals { - - vaddr, isVaddr := v.(*usermem.Addr) - - var n int - var err error - if isVaddr { - value := s.Arch.Native(uintptr(0)) - n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{}) - *vaddr = usermem.Addr(s.Arch.Value(value)) - } else { - n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{}) - } - if err != nil { - return 0, err - } - +// CopyInBytes implements marshal.CopyContext.CopyInBytes. CopyInBytes computes +// an appropriate address based on the current end of the stack. Callers must +// use the sentinel address StackBottomMagic to marshal methods to indicate +// this. +func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) { + if sentinel != StackBottomMagic { + panic("Attempted to copy in from stack with absolute address") + } + n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{}) + if err == nil { s.Bottom += usermem.Addr(n) } - - return s.Bottom, nil + return n, err } // Align aligns the stack to the given offset. @@ -142,6 +99,22 @@ func (s *Stack) Align(offset int) { } } +// PushNullTerminatedByteSlice writes bs to the stack, followed by an extra null +// byte at the end. On error, the contents of the stack and the bottom cursor +// are undefined. +func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) { + // Note: Stack grows up, so write the terminal null byte first. + nNull, err := primitive.CopyUint8Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + n, err := primitive.CopyByteSliceOut(s, StackBottomMagic, bs) + if err != nil { + return 0, err + } + return n + nNull, nil +} + // StackLayout describes the location of the arguments and environment on the // stack. type StackLayout struct { @@ -177,11 +150,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) l.EnvvEnd = s.Bottom envAddrs := make([]usermem.Addr, len(env)) for i := len(env) - 1; i >= 0; i-- { - addr, err := s.Push(env[i]) - if err != nil { + if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil { return StackLayout{}, err } - envAddrs[i] = addr + envAddrs[i] = s.Bottom } l.EnvvStart = s.Bottom @@ -189,11 +161,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) l.ArgvEnd = s.Bottom argAddrs := make([]usermem.Addr, len(args)) for i := len(args) - 1; i >= 0; i-- { - addr, err := s.Push(args[i]) - if err != nil { + if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil { return StackLayout{}, err } - argAddrs[i] = addr + argAddrs[i] = s.Bottom } l.ArgvStart = s.Bottom @@ -222,26 +193,26 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) auxv = append(auxv, usermem.Addr(a.Key), a.Value) } auxv = append(auxv, usermem.Addr(0)) - _, err := s.Push(auxv) + _, err := s.pushAddrSliceAndTerminator(auxv) if err != nil { return StackLayout{}, err } // Push environment. - _, err = s.Push(envAddrs) + _, err = s.pushAddrSliceAndTerminator(envAddrs) if err != nil { return StackLayout{}, err } // Push args. - _, err = s.Push(argAddrs) + _, err = s.pushAddrSliceAndTerminator(argAddrs) if err != nil { return StackLayout{}, err } // Push arg count. - _, err = s.Push(usermem.Addr(len(args))) - if err != nil { + lenP := s.Arch.Native(uintptr(len(args))) + if _, err = lenP.CopyOut(s, StackBottomMagic); err != nil { return StackLayout{}, err } diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go new file mode 100644 index 000000000..a90d297ee --- /dev/null +++ b/pkg/sentry/arch/stack_unsafe.go @@ -0,0 +1,69 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arch + +import ( + "reflect" + "runtime" + "unsafe" + + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/usermem" +) + +// pushAddrSliceAndTerminator copies a slices of addresses to the stack, and +// also pushes an extra null address element at the end of the slice. +// +// Internally, we unsafely transmute the slice type from the arch-dependent +// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to +// go-marshal. +// +// On error, the contents of the stack and the bottom cursor are undefined. +func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) { + // Note: Stack grows upwards, so push the terminator first. + srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + switch s.Arch.Width() { + case 8: + nNull, err := primitive.CopyUint64Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + var dst []uint64 + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstHdr.Data = srcHdr.Data + dstHdr.Len = srcHdr.Len + dstHdr.Cap = srcHdr.Cap + n, err := primitive.CopyUint64SliceOut(s, StackBottomMagic, dst) + // Ensures src doesn't get GCed until we're done using it through dst. + runtime.KeepAlive(src) + return n + nNull, err + case 4: + nNull, err := primitive.CopyUint32Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + var dst []uint32 + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstHdr.Data = srcHdr.Data + dstHdr.Len = srcHdr.Len + dstHdr.Cap = srcHdr.Cap + n, err := primitive.CopyUint32SliceOut(s, StackBottomMagic, dst) + // Ensure src doesn't get GCed until we're done using it through dst. + runtime.KeepAlive(src) + return n + nNull, err + default: + panic("Unsupported arch width") + } +} diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 2c5d14be5..deaf5fa23 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -35,7 +35,6 @@ go_library( "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index dfa936563..668f47802 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -23,8 +23,8 @@ import ( "text/tabwriter" "time" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fdimport" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" @@ -203,27 +203,17 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI } initArgs.Filename = resolved - fds := make([]int, len(args.FilePayload.Files)) - for i, file := range args.FilePayload.Files { - if kernel.VFS2Enabled { - // Need to dup to remove ownership from os.File. - dup, err := unix.Dup(int(file.Fd())) - if err != nil { - return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err) - } - fds[i] = dup - } else { - // VFS1 dups the file on import. - fds[i] = int(file.Fd()) - } + fds, err := fd.NewFromFiles(args.Files) + if err != nil { + return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err) } + defer func() { + for _, fd := range fds { + _ = fd.Close() + } + }() ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, args.StdioIsPty, fds) if err != nil { - if kernel.VFS2Enabled { - for _, fd := range fds { - unix.Close(fd) - } - } return nil, 0, nil, nil, err } diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD index abe58f818..4c8604d58 100644 --- a/pkg/sentry/devices/memdev/BUILD +++ b/pkg/sentry/devices/memdev/BUILD @@ -18,9 +18,10 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go index 511179e31..fece3e762 100644 --- a/pkg/sentry/devices/memdev/full.go +++ b/pkg/sentry/devices/memdev/full.go @@ -24,6 +24,8 @@ import ( const fullDevMinor = 7 // fullDevice implements vfs.Device for /dev/full. +// +// +stateify savable type fullDevice struct{} // Open implements vfs.Device.Open. @@ -38,6 +40,8 @@ func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // fullFD implements vfs.FileDescriptionImpl for /dev/full. +// +// +stateify savable type fullFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go index 4918dbeeb..ff5837747 100644 --- a/pkg/sentry/devices/memdev/null.go +++ b/pkg/sentry/devices/memdev/null.go @@ -25,6 +25,8 @@ import ( const nullDevMinor = 3 // nullDevice implements vfs.Device for /dev/null. +// +// +stateify savable type nullDevice struct{} // Open implements vfs.Device.Open. @@ -39,6 +41,8 @@ func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // nullFD implements vfs.FileDescriptionImpl for /dev/null. +// +// +stateify savable type nullFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go index 5e7fe0280..ac943e3ba 100644 --- a/pkg/sentry/devices/memdev/random.go +++ b/pkg/sentry/devices/memdev/random.go @@ -30,6 +30,8 @@ const ( ) // randomDevice implements vfs.Device for /dev/random and /dev/urandom. +// +// +stateify savable type randomDevice struct{} // Open implements vfs.Device.Open. @@ -44,6 +46,8 @@ func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, } // randomFD implements vfs.FileDescriptionImpl for /dev/random. +// +// +stateify savable type randomFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index 2e631a252..1929e41cd 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -16,9 +16,10 @@ package memdev import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/usermem" ) @@ -26,6 +27,8 @@ import ( const zeroDevMinor = 5 // zeroDevice implements vfs.Device for /dev/zero. +// +// +stateify savable type zeroDevice struct{} // Open implements vfs.Device.Open. @@ -40,6 +43,8 @@ func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // zeroFD implements vfs.FileDescriptionImpl for /dev/zero. +// +// +stateify savable type zeroFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -79,11 +84,22 @@ func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) + if opts.Private || !opts.MaxPerms.Write { + // This mapping will never permit writing to the "underlying file" (in + // Linux terms, it isn't VM_SHARED), so implement it as an anonymous + // mapping, but back it with fd; this is what Linux does, and is + // actually application-visible because the resulting VMA will show up + // in /proc/[pid]/maps with fd.vfsfd.VirtualDentry()'s path rather than + // "/dev/zero (deleted)". + opts.Offset = 0 + opts.MappingIdentity = &fd.vfsfd + opts.MappingIdentity.IncRef() + return nil + } + tmpfsFD, err := tmpfs.NewZeroFile(ctx, auth.CredentialsFromContext(ctx), kernel.KernelFromContext(ctx).ShmMount(), opts.Length) if err != nil { return err } - opts.MappingIdentity = m - opts.Mappable = m - return nil + defer tmpfsFD.DecRef(ctx) + return tmpfsFD.ConfigureMMap(ctx, opts) } diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go index 664e54498..a287c65ca 100644 --- a/pkg/sentry/devices/ttydev/ttydev.go +++ b/pkg/sentry/devices/ttydev/ttydev.go @@ -30,6 +30,8 @@ const ( ) // ttyDevice implements vfs.Device for /dev/tty. +// +// +stateify savable type ttyDevice struct{} // Open implements vfs.Device.Open. diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD index 5e41ceb4e..6b4f8b0ed 100644 --- a/pkg/sentry/fdimport/BUILD +++ b/pkg/sentry/fdimport/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/fd", "//pkg/sentry/fs", "//pkg/sentry/fs/host", "//pkg/sentry/fsimpl/host", diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 1b7cb94c0..314661475 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" @@ -27,8 +28,9 @@ import ( // Import imports a slice of FDs into the given FDTable. If console is true, // sets up TTY for the first 3 FDs in the slice representing stdin, stdout, -// stderr. Upon success, Import takes ownership of all FDs. -func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { +// stderr. Used FDs are either closed or released. It's safe for the caller to +// close any remaining files upon return. +func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { if kernel.VFS2Enabled { ttyFile, err := importVFS2(ctx, fdTable, console, fds) return nil, ttyFile, err @@ -37,7 +39,7 @@ func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []in return ttyFile, nil, err } -func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, error) { +func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, error) { var ttyFile *fs.File for appFD, hostFD := range fds { var appFile *fs.File @@ -46,11 +48,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] // Import the file as a host TTY file. if ttyFile == nil { var err error - appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */) + appFile, err = host.ImportFile(ctx, hostFD.FD(), true /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + _ = hostFD.Close() // FD is dup'd i ImportFile. // Remember this in the TTY file, as we will // use it for the other stdio FDs. @@ -65,11 +68,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] } else { // Import the file as a regular host file. var err error - appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */) + appFile, err = host.ImportFile(ctx, hostFD.FD(), false /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + _ = hostFD.Close() // FD is dup'd i ImportFile. } // Add the file to the FD map. @@ -84,7 +88,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] return ttyFile.FileOperations.(*host.TTYFileOperations), nil } -func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) { +func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []*fd.FD) (*hostvfs2.TTYFileDescription, error) { k := kernel.KernelFromContext(ctx) if k == nil { return nil, fmt.Errorf("cannot find kernel from context") @@ -98,11 +102,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi // Import the file as a host TTY file. if ttyFile == nil { var err error - appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, true /* isTTY */) + appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), true /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + hostFD.Release() // FD is transfered to host FD. // Remember this in the TTY file, as we will use it for the other stdio // FDs. @@ -115,11 +120,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi } } else { var err error - appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */) + appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), false /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + hostFD.Release() // FD is transfered to host FD. } if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil { diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md index 2ca84dd74..05e043583 100644 --- a/pkg/sentry/fs/g3doc/fuse.md +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -79,7 +79,7 @@ ops can be implemented in parallel. - Implement `/dev/fuse` - a character device used to establish an FD for communication between the sentry and the server daemon. -- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`. +- Implement basic FUSE ops like `FUSE_INIT`. #### Read-only mount with basic file operations @@ -95,6 +95,103 @@ ops can be implemented in parallel. - Implement the remaining FUSE ops and decide if we can omit rarely used operations like ioctl. +### Design Details + +#### Lifecycle for a FUSE Request + +- User invokes a syscall +- Sentry prepares corresponding request + - If FUSE device is available + - Write the request in binary + - If FUSE device is full + - Kernel task blocked until available +- Sentry notifies the readers of fuse device that it's ready for read +- FUSE daemon reads the request and processes it +- Sentry waits until a reply is written to the FUSE device + - but returns directly for async requests +- FUSE daemon writes to the fuse device +- Sentry processes the reply + - For sync requests, unblock blocked kernel task + - For async requests, execute pre-specified callback if any +- Sentry returns the syscall to the user + +#### Channels and Queues for Requests in Different Stages + +`connection.initializedChan` + +- a channel that the requests issued before connection initialization blocks + on. + +`fd.queue` + +- a queue of requests that haven’t been read by the FUSE daemon yet. + +`fd.completions` + +- a map of the requests that have been prepared but not yet received a + response, including the ones on the `fd.queue`. + +`fd.waitQueue` + +- a queue of waiters that is waiting for the fuse device fd to be available, + such as the FUSE daemon. + +`fd.fullQueueCh` + +- a channel that the kernel task will be blocked on when the fd is not + available. + +#### Basic I/O Implementation + +Currently we have implemented basic functionalities of read and write for our +FUSE. We describe the design and ways to improve it here: + +##### Basic FUSE Read + +The vfs2 expects implementations of `vfs.FileDescriptionImpl.Read()` and +`vfs.FileDescriptionImpl.PRead()`. When a syscall is made, it will eventually +reach our implementation of those interface functions located at +`pkg/sentry/fsimpl/fuse/regular_file.go` for regular files. + +After validation checks of the input, sentry sends `FUSE_READ` requests to the +FUSE daemon. The FUSE daemon returns data after the `fuse_out_header` as the +responses. For the first version, we create a copy in kernel memory of those +data. They are represented as a byte slice in the marshalled struct. This +happens as a common process for all the FUSE responses at this moment at +`pkg/sentry/fsimpl/fuse/dev.go:writeLocked()`. We then directly copy from this +intermediate buffer to the input buffer provided by the read syscall. + +There is an extra requirement for FUSE: When mounting the FUSE fs, the mounter +or the FUSE daemon can specify a `max_read` or a `max_pages` parameter. They are +the upperbound of the bytes to read in each `FUSE_READ` request. We implemented +the code to handle the fragmented reads. + +To improve the performance: ideally we should have buffer cache to copy those +data from the responses of FUSE daemon into, as is also the design of several +other existing file system implementations for sentry, instead of a single-use +temporary buffer. Directly mapping the memory of one process to another could +also boost the performance, but to keep them isolated, we did not choose to do +so. + +##### Basic FUSE Write + +The vfs2 invokes implementations of `vfs.FileDescriptionImpl.Write()` and +`vfs.FileDescriptionImpl.PWrite()` on the regular file descriptor of FUSE when a +user makes write(2) and pwrite(2) syscall. + +For valid writes, sentry sends the bytes to write after a `FUSE_WRITE` header +(can be regarded as a request with 2 payloads) to the FUSE daemon. For the first +version, we allocate a buffer inside kernel memory to store the bytes from the +user, and copy directly from that buffer to the memory of FUSE daemon. This +happens at `pkg/sentry/fsimpl/fuse/dev.go:readLocked()` + +The parameters `max_write` and `max_pages` restrict the number of bytes in one +`FUSE_WRITE`. There are code handling fragmented writes in current +implementation. + +To have better performance: the extra copy created to store the bytes to write +can be replaced by the buffer cache as well. + # Appendix ## FUSE Protocol diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 42a6c41c2..1368014c4 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/fdnotifier", "//pkg/iovec", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/secio", @@ -55,7 +56,6 @@ go_library( "//pkg/unet", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index 5d4f312cf..c8231e0aa 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -65,10 +65,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) ( controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC if n > length { - return length, n, msg.Controllen, controlTrunc, err + return length, n, msg.Controllen, controlTrunc, nil } - return n, n, msg.Controllen, controlTrunc, err + return n, n, msg.Controllen, controlTrunc, nil } // fdWriteVec sends from bufs to fd. diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 67a807f9d..1183727ab 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -24,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -54,7 +54,7 @@ type TTYFileOperations struct { func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File { return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{ fileOperations: fileOperations{iops: iops}, - termios: linux.DefaultSlaveTermios, + termios: linux.DefaultReplicaTermios, }) } diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index b79cd9877..004910453 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -270,7 +270,7 @@ func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, // SetXattr calls i.InodeOperations.SetXattr with i as the Inode. func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error { if i.overlay != nil { - return overlaySetxattr(ctx, i.overlay, d, name, value, flags) + return overlaySetXattr(ctx, i.overlay, d, name, value, flags) } return i.InodeOperations.SetXattr(ctx, i, name, value, flags) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index dc2e353d9..b16ab08ba 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -16,7 +16,6 @@ package fs import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -539,7 +538,7 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin // Don't forward the value of the extended attribute if it would // unexpectedly change the behavior of a wrapping overlay layer. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return "", syserror.ENODATA } @@ -553,9 +552,9 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin return s, err } -func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error { +func overlaySetXattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error { // Don't allow changes to overlay xattrs through a setxattr syscall. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return syserror.EPERM } @@ -578,7 +577,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st for name := range names { // Same as overlayGetXattr, we shouldn't forward along // overlay attributes. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { delete(names, name) } } @@ -587,7 +586,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error { // Don't allow changes to overlay xattrs through a removexattr syscall. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return syserror.EPERM } diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 77c2c5c0e..b8b2281a8 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 8615b60f0..e555672ad 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,7 +55,7 @@ type tcpMemInode struct { // size stores the tcp buffer size during save, and sets the buffer // size in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. + // a netstack instance is created on restore. size inet.TCPBufferSize // mu protects against concurrent reads/writes to files based on this @@ -258,6 +259,9 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque if src.NumBytes() == 0 { return 0, nil } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -383,11 +387,125 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } +// ipForwarding implements fs.InodeOperations. +// +// ipForwarding is used to enable/disable packet forwarding of netstack. +// +// +stateify savable +type ipForwarding struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + + // enabled stores the IPv4 forwarding state on save. + // We must save/restore this here, since a netstack instance + // is created on restore. + enabled *bool +} + +func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &ipForwarding{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*ipForwarding) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// +stateify savable +type ipForwardingFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + ipf *ipForwarding + + stack inet.Stack `state:"wait"` +} + +// GetFile implements fs.InodeOperations.GetFile. +func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{ + stack: ipf.stack, + ipf: ipf, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + if f.ipf.enabled == nil { + enabled := f.stack.Forwarding(ipv4.ProtocolNumber) + f.ipf.enabled = &enabled + } + + val := "0\n" + if *f.ipf.enabled { + // Technically, this is not quite compatible with Linux. Linux + // stores these as an integer, so if you write "2" into + // ip_forward, you should get 2 back. + val = "1\n" + } + n, err := dst.CopyOut(ctx, []byte(val)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return n, err + } + if f.ipf.enabled == nil { + f.ipf.enabled = new(bool) + } + *f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. "tcp_sack": newTCPSackInode(ctx, msrc, s), + // Add ip_forward. + "ip_forward": newIPForwardingInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 6eba709c6..4cb4741af 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -14,7 +14,11 @@ package proc -import "fmt" +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" +) // beforeSave is invoked by stateify. func (t *tcpMemInode) beforeSave() { @@ -40,3 +44,12 @@ func (s *tcpSack) afterLoad() { } } } + +// afterLoad is invoked by stateify. +func (ipf *ipForwarding) afterLoad() { + if ipf.enabled != nil { + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) + } + } +} diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 355e83d47..6ef5738e7 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -123,3 +123,76 @@ func TestConfigureRecvBufferSize(t *testing.T) { } } } + +// TestIPForwarding tests the implementation of +// /proc/sys/net/ipv4/ip_forwarding +func TestIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + ipf := &ipForwarding{stack: s} + file := &ipForwardingFile{ + stack: s, + ipf: ipf, + } + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) + } + + }) + } +} diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 9cf7f2a62..22d658acf 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -84,6 +84,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "auxv": newAuxvec(t, msrc), "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), "comm": newComm(t, msrc), + "cwd": newCwd(t, msrc), "environ": newExecArgInode(t, msrc, environExecArg), "exe": newExe(t, msrc), "fd": newFdDir(t, msrc), @@ -300,6 +301,49 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { return exec.PathnameWithDeleted(ctx), nil } +// cwd is an fs.InodeOperations symlink for the /proc/PID/cwd file. +// +// +stateify savable +type cwd struct { + ramfs.Symlink + + t *kernel.Task +} + +func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + cwdSymlink := &cwd{ + Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + t: t, + } + return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t) +} + +// Readlink implements fs.InodeOperations. +func (e *cwd) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { + if !kernel.ContextCanTrace(ctx, e.t, false) { + return "", syserror.EACCES + } + if err := checkTaskState(e.t); err != nil { + return "", err + } + cwd := e.t.FSContext().WorkingDirectory() + if cwd == nil { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer cwd.DecRef(ctx) + + root := fs.RootFromContext(ctx) + if root == nil { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + name, _ := cwd.FullName(root) + return name, nil +} + // namespaceSymlink represents a symlink in the namespacefs, such as the files // in /proc/<pid>/ns. // @@ -604,7 +648,7 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( var vss, rss, data uint64 s.t.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5cb0e0417..e6d0eb359 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -10,13 +10,14 @@ go_library( "line_discipline.go", "master.go", "queue.go", - "slave.go", + "replica.go", "terminal.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 463f6189e..c2da80bc2 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -37,14 +37,14 @@ import ( // This indirectly manages all terminals within the mount. // // New Terminals are created by masterInodeOperations.GetFile, which registers -// the slave Inode in the this directory for discovery via Lookup/Readdir. The -// slave inode is unregistered when the master file is Released, as the slave +// the replica Inode in the this directory for discovery via Lookup/Readdir. The +// replica inode is unregistered when the master file is Released, as the replica // is no longer discoverable at that point. // // References on the underlying Terminal are held by masterFileOperations and -// slaveInodeOperations. +// replicaInodeOperations. // -// masterInodeOperations and slaveInodeOperations hold a pointer to +// masterInodeOperations and replicaInodeOperations hold a pointer to // dirInodeOperations, which is reference counted by the refcount their // corresponding Dirents hold on their parent (this directory). // @@ -76,16 +76,16 @@ type dirInodeOperations struct { // master is the master PTY inode. master *fs.Inode - // slaves contains the slave inodes reachable from the directory. + // replicas contains the replica inodes reachable from the directory. // - // A new slave is added by allocateTerminal and is removed by + // A new replica is added by allocateTerminal and is removed by // masterFileOperations.Release. // - // A reference is held on every slave in the map. - slaves map[uint32]*fs.Inode + // A reference is held on every replica in the map. + replicas map[uint32]*fs.Inode // dentryMap is a SortedDentryMap used to implement Readdir containing - // the master and all entries in slaves. + // the master and all entries in replicas. dentryMap *fs.SortedDentryMap // next is the next pty index to use. @@ -101,7 +101,7 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { d := &dirInodeOperations{ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0555), linux.DEVPTS_SUPER_MAGIC), msrc: m, - slaves: make(map[uint32]*fs.Inode), + replicas: make(map[uint32]*fs.Inode), dentryMap: fs.NewSortedDentryMap(nil), } // Linux devpts uses a default mode of 0000 for ptmx which can be @@ -133,7 +133,7 @@ func (d *dirInodeOperations) Release(ctx context.Context) { defer d.mu.Unlock() d.master.DecRef(ctx) - if len(d.slaves) != 0 { + if len(d.replicas) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) } } @@ -149,14 +149,14 @@ func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name str return fs.NewDirent(ctx, d.master, name), nil } - // Slave number? + // Replica number? n, err := strconv.ParseUint(name, 10, 32) if err != nil { // Not found. return nil, syserror.ENOENT } - s, ok := d.slaves[uint32(n)] + s, ok := d.replicas[uint32(n)] if !ok { return nil, syserror.ENOENT } @@ -236,7 +236,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e return nil, syserror.ENOMEM } - if _, ok := d.slaves[n]; ok { + if _, ok := d.replicas[n]; ok { panic(fmt.Sprintf("pty index collision; index %d already exists", n)) } @@ -244,19 +244,19 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e d.next++ // The reference returned by newTerminal is returned to the caller. - // Take another for the slave inode. + // Take another for the replica inode. t.IncRef() // Create a pts node. The owner is based on the context that opens // ptmx. creds := auth.CredentialsFromContext(ctx) uid, gid := creds.EffectiveKUID, creds.EffectiveKGID - slave := newSlaveInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666)) + replica := newReplicaInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666)) - d.slaves[n] = slave + d.replicas[n] = replica d.dentryMap.Add(strconv.FormatUint(uint64(n), 10), fs.DentAttr{ - Type: slave.StableAttr.Type, - InodeID: slave.StableAttr.InodeID, + Type: replica.StableAttr.Type, + InodeID: replica.StableAttr.InodeID, }) return t, nil @@ -267,18 +267,18 @@ func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) { d.mu.Lock() defer d.mu.Unlock() - // The slave end disappears from the directory when the master end is - // closed, even if the slave end is open elsewhere. + // The replica end disappears from the directory when the master end is + // closed, even if the replica end is open elsewhere. // // N.B. since we're using a backdoor method to remove a directory entry // we won't properly fire inotify events like Linux would. - s, ok := d.slaves[t.n] + s, ok := d.replicas[t.n] if !ok { panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d)) } s.DecRef(ctx) - delete(d.slaves, t.n) + delete(d.replicas, t.n) d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10)) } diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go index 2d4d44bf3..13f4901db 100644 --- a/pkg/sentry/fs/tty/fs.go +++ b/pkg/sentry/fs/tty/fs.go @@ -79,8 +79,8 @@ type superOperations struct{} // // It always returns true, forcing a Lookup for all entries. // -// Slave entries are dropped from dir when their master is closed, so an -// existing slave Dirent in the tree is not sufficient to guarantee that it +// Replica entries are dropped from dir when their master is closed, so an +// existing replica Dirent in the tree is not sufficient to guarantee that it // still exists on the filesystem. func (superOperations) Revalidate(context.Context, string, *fs.Inode, *fs.Inode) bool { return true diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 2e9dd2d55..b34f4a0eb 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -43,7 +44,7 @@ const ( ) // lineDiscipline dictates how input and output are handled between the -// pseudoterminal (pty) master and slave. It can be configured to alter I/O, +// pseudoterminal (pty) master and replica. It can be configured to alter I/O, // modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man // pages are good resources for how to affect the line discipline: // @@ -54,8 +55,8 @@ const ( // // lineDiscipline has a simple structure but supports a multitude of options // (see the above man pages). It consists of two queues of bytes: one from the -// terminal master to slave (the input queue) and one from slave to master (the -// output queue). When bytes are written to one end of the pty, the line +// terminal master to replica (the input queue) and one from replica to master +// (the output queue). When bytes are written to one end of the pty, the line // discipline reads the bytes, modifies them or takes special action if // required, and enqueues them to be read by the other end of the pty: // @@ -64,7 +65,7 @@ const ( // | (inputQueueWrite) +-------------+ (inputQueueRead) | // | | // | v -// masterFD slaveFD +// masterFD replicaFD // ^ | // | | // | output to terminal +--------------+ output from process | @@ -103,8 +104,8 @@ type lineDiscipline struct { // masterWaiter is used to wait on the master end of the TTY. masterWaiter waiter.Queue `state:"zerovalue"` - // slaveWaiter is used to wait on the slave end of the TTY. - slaveWaiter waiter.Queue `state:"zerovalue"` + // replicaWaiter is used to wait on the replica end of the TTY. + replicaWaiter waiter.Queue `state:"zerovalue"` } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { @@ -115,27 +116,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { } // getTermios gets the linux.Termios for the tty. -func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.RLock() defer l.termiosMu.RUnlock() // We must copy a Termios struct, not KernelTermios. t := l.termios.ToTermios() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyOut(task, args[2].Pointer()) return 0, err } // setTermios sets a linux.Termios for the tty. -func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.Lock() defer l.termiosMu.Unlock() oldCanonEnabled := l.termios.LEnabled(linux.ICANON) // We must copy a Termios struct, not KernelTermios. var t linux.Termios - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyIn(task, args[2].Pointer()) l.termios.FromTermios(t) // If canonical mode is turned off, move bytes from inQueue's wait @@ -146,27 +143,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc l.inQueue.pushWaitBufLocked(l) l.inQueue.readable = true l.inQueue.mu.Unlock() - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return 0, err } -func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyOut(t, args[2].Pointer()) return err } -func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyIn(t, args[2].Pointer()) return err } @@ -176,14 +169,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask { return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios) } -func (l *lineDiscipline) slaveReadiness() waiter.EventMask { +func (l *lineDiscipline) replicaReadiness() waiter.EventMask { l.termiosMu.RLock() defer l.termiosMu.RUnlock() return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios) } -func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.inQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { + return l.inQueue.readableSize(t, args) } func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -196,7 +189,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque if n > 0 { l.masterWaiter.Notify(waiter.EventOut) if pushed { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return n, nil } @@ -211,14 +204,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) return n, nil } return 0, syserror.ErrWouldBlock } -func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.outQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { + return l.outQueue.readableSize(t, args) } func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -229,7 +222,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventOut) + l.replicaWaiter.Notify(waiter.EventOut) if pushed { l.masterWaiter.Notify(waiter.EventIn) } diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index e00746017..b91184b1b 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -17,9 +17,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -152,46 +154,51 @@ func (mf *masterFileOperations) Write(ctx context.Context, _ *fs.File, src userm // Ioctl implements fs.FileOperations.Ioctl. func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the output queue read buffer. - return 0, mf.t.ld.outputQueueReadSize(ctx, io, args) + return 0, mf.t.ld.outputQueueReadSize(t, args) case linux.TCGETS: // N.B. TCGETS on the master actually returns the configuration - // of the slave end. - return mf.t.ld.getTermios(ctx, io, args) + // of the replica end. + return mf.t.ld.getTermios(t, args) case linux.TCSETS: // N.B. TCSETS on the master actually affects the configuration - // of the slave end. - return mf.t.ld.setTermios(ctx, io, args) + // of the replica end. + return mf.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return mf.t.ld.setTermios(ctx, io, args) + return mf.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mf.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(mf.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCSPTLCK: // TODO(b/29356795): Implement pty locking. For now just pretend we do. return 0, nil case linux.TIOCGWINSZ: - return 0, mf.t.ld.windowSize(ctx, io, args) + return 0, mf.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, mf.t.ld.setWindowSize(ctx, io, args) + return 0, mf.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mf.t.setControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mf.t.releaseControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + return mf.t.foregroundProcessGroup(ctx, args, true /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) + return mf.t.setForegroundProcessGroup(ctx, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index c5d7ec717..79975d812 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -17,8 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -32,7 +34,7 @@ import ( const waitBufMaxBytes = 131072 // queue represents one of the input or output queues between a pty master and -// slave. Bytes written to a queue are added to the read buffer until it is +// replica. Bytes written to a queue are added to the read buffer until it is // full, at which point they are written to the wait buffer. Bytes are // processed (i.e. undergo termios transformations) as they are added to the // read buffer. The read buffer is readable when its length is nonzero and @@ -85,17 +87,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask { } // readableSize writes the number of readable bytes to userspace. -func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (q *queue) readableSize(t *kernel.Task, args arch.SyscallArguments) error { q.mu.Lock() defer q.mu.Unlock() - var size int32 + size := primitive.Int32(0) if q.readable { - size = int32(len(q.readBuf)) + size = primitive.Int32(len(q.readBuf)) } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := size.CopyOut(t, args[2].Pointer()) return err } diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/replica.go index 7c7292687..385d230fb 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/replica.go @@ -17,9 +17,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -27,11 +29,11 @@ import ( // LINT.IfChange -// slaveInodeOperations are the fs.InodeOperations for the slave end of the +// replicaInodeOperations are the fs.InodeOperations for the replica end of the // Terminal (pts file). // // +stateify savable -type slaveInodeOperations struct { +type replicaInodeOperations struct { fsutil.SimpleFileInode // d is the containing dir. @@ -41,13 +43,13 @@ type slaveInodeOperations struct { t *Terminal } -var _ fs.InodeOperations = (*slaveInodeOperations)(nil) +var _ fs.InodeOperations = (*replicaInodeOperations)(nil) -// newSlaveInode creates an fs.Inode for the slave end of a terminal. +// newReplicaInode creates an fs.Inode for the replica end of a terminal. // -// newSlaveInode takes ownership of t. -func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode { - iops := &slaveInodeOperations{ +// newReplicaInode takes ownership of t. +func newReplicaInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode { + iops := &replicaInodeOperations{ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC), d: d, t: t, @@ -64,18 +66,18 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne Type: fs.CharacterDevice, // See fs/devpts/inode.c:devpts_fill_super. BlockSize: 1024, - DeviceFileMajor: linux.UNIX98_PTY_SLAVE_MAJOR, + DeviceFileMajor: linux.UNIX98_PTY_REPLICA_MAJOR, DeviceFileMinor: t.n, }) } // Release implements fs.InodeOperations.Release. -func (si *slaveInodeOperations) Release(ctx context.Context) { +func (si *replicaInodeOperations) Release(ctx context.Context) { si.t.DecRef(ctx) } // Truncate implements fs.InodeOperations.Truncate. -func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { +func (*replicaInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { return nil } @@ -83,14 +85,15 @@ func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { // // This may race with destruction of the terminal. If the terminal is gone, it // returns ENOENT. -func (si *slaveInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, d, flags, &slaveFileOperations{si: si}), nil +func (si *replicaInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &replicaFileOperations{si: si}), nil } -// slaveFileOperations are the fs.FileOperations for the slave end of a terminal. +// replicaFileOperations are the fs.FileOperations for the replica end of a +// terminal. // // +stateify savable -type slaveFileOperations struct { +type replicaFileOperations struct { fsutil.FilePipeSeek `state:"nosave"` fsutil.FileNotDirReaddir `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` @@ -100,79 +103,84 @@ type slaveFileOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` // si is the inode operations. - si *slaveInodeOperations + si *replicaInodeOperations } -var _ fs.FileOperations = (*slaveFileOperations)(nil) +var _ fs.FileOperations = (*replicaFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (sf *slaveFileOperations) Release(context.Context) { +func (sf *replicaFileOperations) Release(context.Context) { } // EventRegister implements waiter.Waitable.EventRegister. -func (sf *slaveFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - sf.si.t.ld.slaveWaiter.EventRegister(e, mask) +func (sf *replicaFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + sf.si.t.ld.replicaWaiter.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (sf *slaveFileOperations) EventUnregister(e *waiter.Entry) { - sf.si.t.ld.slaveWaiter.EventUnregister(e) +func (sf *replicaFileOperations) EventUnregister(e *waiter.Entry) { + sf.si.t.ld.replicaWaiter.EventUnregister(e) } // Readiness implements waiter.Waitable.Readiness. -func (sf *slaveFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return sf.si.t.ld.slaveReadiness() +func (sf *replicaFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return sf.si.t.ld.replicaReadiness() } // Read implements fs.FileOperations.Read. -func (sf *slaveFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { +func (sf *replicaFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { return sf.si.t.ld.inputQueueRead(ctx, dst) } // Write implements fs.FileOperations.Write. -func (sf *slaveFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { +func (sf *replicaFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { return sf.si.t.ld.outputQueueWrite(ctx, src) } // Ioctl implements fs.FileOperations.Ioctl. -func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (sf *replicaFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the input queue read buffer. - return 0, sf.si.t.ld.inputQueueReadSize(ctx, io, args) + return 0, sf.si.t.ld.inputQueueReadSize(t, args) case linux.TCGETS: - return sf.si.t.ld.getTermios(ctx, io, args) + return sf.si.t.ld.getTermios(t, args) case linux.TCSETS: - return sf.si.t.ld.setTermios(ctx, io, args) + return sf.si.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return sf.si.t.ld.setTermios(ctx, io, args) + return sf.si.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sf.si.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(sf.si.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCGWINSZ: - return 0, sf.si.t.ld.windowSize(ctx, io, args) + return 0, sf.si.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, sf.si.t.ld.setWindowSize(ctx, io, args) + return 0, sf.si.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sf.si.t.setControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sf.si.t.releaseControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + return sf.si.t.foregroundProcessGroup(ctx, args, false /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) + return sf.si.t.setForegroundProcessGroup(ctx, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY } } -// LINT.ThenChange(../../fsimpl/devpts/slave.go) +// LINT.ThenChange(../../fsimpl/devpts/replica.go) diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index ddcccf4da..4f431d74d 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,10 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -44,19 +44,19 @@ type Terminal struct { // this terminal. This field is immutable. masterKTTY *kernel.TTY - // slaveKTTY contains the controlling process of the slave end of this + // replicaKTTY contains the controlling process of the replica end of this // terminal. This field is immutable. - slaveKTTY *kernel.TTY + replicaKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { - termios := linux.DefaultSlaveTermios + termios := linux.DefaultReplicaTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{Index: n}, - slaveKTTY: &kernel.TTY{Index: n}, + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + replicaKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t @@ -64,7 +64,7 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal // setControllingTTY makes tm the controlling terminal of the calling thread // group. -func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("setControllingTTY must be called from a task context") @@ -75,7 +75,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a // releaseControllingTTY removes tm as the controlling terminal of the calling // thread group. -func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("releaseControllingTTY must be called from a task context") @@ -85,7 +85,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar } // foregroundProcessGroup gets the process group ID of tm's foreground process. -func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("foregroundProcessGroup must be called from a task context") @@ -97,24 +97,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a } // Write it out to *arg. - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ - AddressSpaceActive: true, - }) + retP := primitive.Int32(ret) + _, err = retP.CopyOut(task, args[2].Pointer()) return 0, err } // foregroundProcessGroup sets tm's foreground process. -func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("setForegroundProcessGroup must be called from a task context") } // Read in the process group ID. - var pgid int32 - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgid primitive.Int32 + if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } @@ -126,7 +123,7 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { if isMaster { return tm.masterKTTY } - return tm.slaveKTTY + return tm.replicaKTTY } // LINT.ThenChange(../../fsimpl/devpts/terminal.go) diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go index 2cbc05678..49edee83d 100644 --- a/pkg/sentry/fs/tty/tty_test.go +++ b/pkg/sentry/fs/tty/tty_test.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) +func TestSimpleMasterToReplica(t *testing.T) { + ld := newLineDiscipline(linux.DefaultReplicaTermios) ctx := contexttest.Context(t) inBytes := []byte("hello, tty\n") src := usermem.BytesIOSequence(inBytes) diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 3f64fab3a..48e13613a 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -21,8 +21,8 @@ go_library( "line_discipline.go", "master.go", "queue.go", + "replica.go", "root_inode_refs.go", - "slave.go", "terminal.go", ], visibility = ["//pkg/sentry:internal"], @@ -30,6 +30,8 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 57580f4d4..903135fae 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -35,6 +35,8 @@ import ( const Name = "devpts" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // Name implements vfs.FilesystemType.Name. @@ -58,6 +60,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -79,7 +82,7 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds // Construct the root directory. This is always inode id 1. root := &rootInode{ - slaves: make(map[uint32]*slaveInode), + replicas: make(map[uint32]*replicaInode), } root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) @@ -110,6 +113,8 @@ func (fs *filesystem) Release(ctx context.Context) { } // rootInode is the root directory inode for the devpts mounts. +// +// +stateify savable type rootInode struct { implStatFS kernfs.AlwaysValid @@ -131,10 +136,10 @@ type rootInode struct { root *rootInode // mu protects the fields below. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` - // slaves maps pty ids to slave inodes. - slaves map[uint32]*slaveInode + // replicas maps pty ids to replica inodes. + replicas map[uint32]*replicaInode // nextIdx is the next pty index to use. Must be accessed atomically. // @@ -154,22 +159,22 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) idx := i.nextIdx i.nextIdx++ - // Sanity check that slave with idx does not exist. - if _, ok := i.slaves[idx]; ok { + // Sanity check that replica with idx does not exist. + if _, ok := i.replicas[idx]; ok { panic(fmt.Sprintf("pty index collision; index %d already exists", idx)) } - // Create the new terminal and slave. + // Create the new terminal and replica. t := newTerminal(idx) - slave := &slaveInode{ + replica := &replicaInode{ root: i, t: t, } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) - slave.dentry.Init(slave) - i.slaves[idx] = slave + replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.dentry.Init(replica) + i.replicas[idx] = replica return t, nil } @@ -179,16 +184,16 @@ func (i *rootInode) masterClose(t *Terminal) { i.mu.Lock() defer i.mu.Unlock() - // Sanity check that slave with idx exists. - if _, ok := i.slaves[t.n]; !ok { + // Sanity check that replica with idx exists. + if _, ok := i.replicas[t.n]; !ok { panic(fmt.Sprintf("pty with index %d does not exist", t.n)) } - delete(i.slaves, t.n) + delete(i.replicas, t.n) } // Open implements kernfs.Inode.Open. -func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ +func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndStaticEntries, }) if err != nil { @@ -198,16 +203,16 @@ func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D } // Lookup implements kernfs.Inode.Lookup. -func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (i *rootInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { idx, err := strconv.ParseUint(name, 10, 32) if err != nil { return nil, syserror.ENOENT } i.mu.Lock() defer i.mu.Unlock() - if si, ok := i.slaves[uint32(idx)]; ok { + if si, ok := i.replicas[uint32(idx)]; ok { si.dentry.IncRef() - return si.dentry.VFSDentry(), nil + return &si.dentry, nil } return nil, syserror.ENOENT @@ -217,8 +222,8 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() - ids := make([]int, 0, len(i.slaves)) - for id := range i.slaves { + ids := make([]int, 0, len(i.replicas)) + for id := range i.replicas { ids = append(ids, int(id)) } sort.Ints(ids) @@ -226,7 +231,7 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, dirent := vfs.Dirent{ Name: strconv.FormatUint(uint64(id), 10), Type: linux.DT_CHR, - Ino: i.slaves[uint32(id)].InodeAttrs.Ino(), + Ino: i.replicas[uint32(id)].InodeAttrs.Ino(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { @@ -237,11 +242,12 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, return offset, nil } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *rootInode) DecRef(context.Context) { i.rootInodeRefs.DecRef(i.Destroy) } +// +stateify savable type implStatFS struct{} // StatFS implements kernfs.Inode.StatFS. diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go index b7c149047..448390cfe 100644 --- a/pkg/sentry/fsimpl/devpts/devpts_test.go +++ b/pkg/sentry/fsimpl/devpts/devpts_test.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) +func TestSimpleMasterToReplica(t *testing.T) { + ld := newLineDiscipline(linux.DefaultReplicaTermios) ctx := contexttest.Context(t) inBytes := []byte("hello, tty\n") src := usermem.BytesIOSequence(inBytes) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index f7bc325d1..e6b0e81cf 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -41,7 +42,7 @@ const ( ) // lineDiscipline dictates how input and output are handled between the -// pseudoterminal (pty) master and slave. It can be configured to alter I/O, +// pseudoterminal (pty) master and replica. It can be configured to alter I/O, // modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man // pages are good resources for how to affect the line discipline: // @@ -52,8 +53,8 @@ const ( // // lineDiscipline has a simple structure but supports a multitude of options // (see the above man pages). It consists of two queues of bytes: one from the -// terminal master to slave (the input queue) and one from slave to master (the -// output queue). When bytes are written to one end of the pty, the line +// terminal master to replica (the input queue) and one from replica to master +// (the output queue). When bytes are written to one end of the pty, the line // discipline reads the bytes, modifies them or takes special action if // required, and enqueues them to be read by the other end of the pty: // @@ -62,7 +63,7 @@ const ( // | (inputQueueWrite) +-------------+ (inputQueueRead) | // | | // | v -// masterFD slaveFD +// masterFD replicaFD // ^ | // | | // | output to terminal +--------------+ output from process | @@ -101,8 +102,8 @@ type lineDiscipline struct { // masterWaiter is used to wait on the master end of the TTY. masterWaiter waiter.Queue `state:"zerovalue"` - // slaveWaiter is used to wait on the slave end of the TTY. - slaveWaiter waiter.Queue `state:"zerovalue"` + // replicaWaiter is used to wait on the replica end of the TTY. + replicaWaiter waiter.Queue `state:"zerovalue"` } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { @@ -113,27 +114,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { } // getTermios gets the linux.Termios for the tty. -func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.RLock() defer l.termiosMu.RUnlock() // We must copy a Termios struct, not KernelTermios. t := l.termios.ToTermios() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyOut(task, args[2].Pointer()) return 0, err } // setTermios sets a linux.Termios for the tty. -func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.Lock() defer l.termiosMu.Unlock() oldCanonEnabled := l.termios.LEnabled(linux.ICANON) // We must copy a Termios struct, not KernelTermios. var t linux.Termios - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyIn(task, args[2].Pointer()) l.termios.FromTermios(t) // If canonical mode is turned off, move bytes from inQueue's wait @@ -144,27 +141,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc l.inQueue.pushWaitBufLocked(l) l.inQueue.readable = true l.inQueue.mu.Unlock() - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return 0, err } -func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyOut(t, args[2].Pointer()) return err } -func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyIn(t, args[2].Pointer()) return err } @@ -174,14 +167,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask { return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios) } -func (l *lineDiscipline) slaveReadiness() waiter.EventMask { +func (l *lineDiscipline) replicaReadiness() waiter.EventMask { l.termiosMu.RLock() defer l.termiosMu.RUnlock() return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios) } -func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.inQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { + return l.inQueue.readableSize(t, io, args) } func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -194,7 +187,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque if n > 0 { l.masterWaiter.Notify(waiter.EventOut) if pushed { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return n, nil } @@ -209,14 +202,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) return n, nil } return 0, syserror.ErrWouldBlock } -func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.outQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { + return l.outQueue.readableSize(t, io, args) } func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -227,7 +220,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventOut) + l.replicaWaiter.Notify(waiter.EventOut) if pushed { l.masterWaiter.Notify(waiter.EventIn) } diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 60feb1993..69c2fe951 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -17,9 +17,11 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -29,6 +31,8 @@ import ( ) // masterInode is the inode for the master end of the Terminal. +// +// +stateify savable type masterInode struct { implStatFS kernfs.InodeAttrs @@ -48,20 +52,18 @@ type masterInode struct { var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. -func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { t, err := mi.root.allocateTerminal(rp.Credentials()) if err != nil { return nil, err } - mi.IncRef() fd := &masterFileDescription{ inode: mi, t: t, } fd.LockFD.Init(&mi.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { - mi.DecRef(ctx) + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil @@ -87,6 +89,7 @@ func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) } +// +stateify savable type masterFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -101,7 +104,6 @@ var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil) // Release implements vfs.FileDescriptionImpl.Release. func (mfd *masterFileDescription) Release(ctx context.Context) { mfd.inode.root.masterClose(mfd.t) - mfd.inode.DecRef(ctx) } // EventRegister implements waiter.Waitable.EventRegister. @@ -131,46 +133,51 @@ func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSeque // Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the output queue read buffer. - return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args) + return 0, mfd.t.ld.outputQueueReadSize(t, io, args) case linux.TCGETS: // N.B. TCGETS on the master actually returns the configuration - // of the slave end. - return mfd.t.ld.getTermios(ctx, io, args) + // of the replica end. + return mfd.t.ld.getTermios(t, args) case linux.TCSETS: // N.B. TCSETS on the master actually affects the configuration - // of the slave end. - return mfd.t.ld.setTermios(ctx, io, args) + // of the replica end. + return mfd.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return mfd.t.ld.setTermios(ctx, io, args) + return mfd.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(mfd.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCSPTLCK: // TODO(b/29356795): Implement pty locking. For now just pretend we do. return 0, nil case linux.TIOCGWINSZ: - return 0, mfd.t.ld.windowSize(ctx, io, args) + return 0, mfd.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, mfd.t.ld.setWindowSize(ctx, io, args) + return 0, mfd.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mfd.t.setControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mfd.t.releaseControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + return mfd.t.foregroundProcessGroup(ctx, args, true /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) + return mfd.t.setForegroundProcessGroup(ctx, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go index 331c13997..55bff3e60 100644 --- a/pkg/sentry/fsimpl/devpts/queue.go +++ b/pkg/sentry/fsimpl/devpts/queue.go @@ -17,8 +17,10 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -30,7 +32,7 @@ import ( const waitBufMaxBytes = 131072 // queue represents one of the input or output queues between a pty master and -// slave. Bytes written to a queue are added to the read buffer until it is +// replica. Bytes written to a queue are added to the read buffer until it is // full, at which point they are written to the wait buffer. Bytes are // processed (i.e. undergo termios transformations) as they are added to the // read buffer. The read buffer is readable when its length is nonzero and @@ -83,17 +85,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask { } // readableSize writes the number of readable bytes to userspace. -func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (q *queue) readableSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { q.mu.Lock() defer q.mu.Unlock() - var size int32 + size := primitive.Int32(0) if q.readable { - size = int32(len(q.readBuf)) + size = primitive.Int32(len(q.readBuf)) } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := size.CopyOut(t, args[2].Pointer()) return err } diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go new file mode 100644 index 000000000..6515c5536 --- /dev/null +++ b/pkg/sentry/fsimpl/devpts/replica.go @@ -0,0 +1,204 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devpts + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// replicaInode is the inode for the replica end of the Terminal. +// +// +stateify savable +type replicaInode struct { + implStatFS + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + locks vfs.FileLocks + + // Keep a reference to this inode's dentry. + dentry kernfs.Dentry + + // root is the devpts root inode. + root *rootInode + + // t is the connected Terminal. + t *Terminal +} + +var _ kernfs.Inode = (*replicaInode)(nil) + +// Open implements kernfs.Inode.Open. +func (ri *replicaInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &replicaFileDescription{ + inode: ri, + } + fd.LockFD.Init(&ri.locks) + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return &fd.vfsfd, nil + +} + +// Valid implements kernfs.Inode.Valid. +func (ri *replicaInode) Valid(context.Context) bool { + // Return valid if the replica still exists. + ri.root.mu.Lock() + defer ri.root.mu.Unlock() + _, ok := ri.root.replicas[ri.t.n] + return ok +} + +// Stat implements kernfs.Inode.Stat. +func (ri *replicaInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := ri.InodeAttrs.Stat(ctx, vfsfs, opts) + if err != nil { + return linux.Statx{}, err + } + statx.Blksize = 1024 + statx.RdevMajor = linux.UNIX98_PTY_REPLICA_MAJOR + statx.RdevMinor = ri.t.n + return statx, nil +} + +// SetStat implements kernfs.Inode.SetStat +func (ri *replicaInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + if opts.Stat.Mask&linux.STATX_SIZE != 0 { + return syserror.EINVAL + } + return ri.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) +} + +// +stateify savable +type replicaFileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *replicaInode +} + +var _ vfs.FileDescriptionImpl = (*replicaFileDescription)(nil) + +// Release implements fs.FileOperations.Release. +func (rfd *replicaFileDescription) Release(ctx context.Context) {} + +// EventRegister implements waiter.Waitable.EventRegister. +func (rfd *replicaFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + rfd.inode.t.ld.replicaWaiter.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (rfd *replicaFileDescription) EventUnregister(e *waiter.Entry) { + rfd.inode.t.ld.replicaWaiter.EventUnregister(e) +} + +// Readiness implements waiter.Waitable.Readiness. +func (rfd *replicaFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { + return rfd.inode.t.ld.replicaReadiness() +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (rfd *replicaFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { + return rfd.inode.t.ld.inputQueueRead(ctx, dst) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (rfd *replicaFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { + return rfd.inode.t.ld.outputQueueWrite(ctx, src) +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (rfd *replicaFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + + switch cmd := args[1].Uint(); cmd { + case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ + // Get the number of bytes in the input queue read buffer. + return 0, rfd.inode.t.ld.inputQueueReadSize(t, io, args) + case linux.TCGETS: + return rfd.inode.t.ld.getTermios(t, args) + case linux.TCSETS: + return rfd.inode.t.ld.setTermios(t, args) + case linux.TCSETSW: + // TODO(b/29356795): This should drain the output queue first. + return rfd.inode.t.ld.setTermios(t, args) + case linux.TIOCGPTN: + nP := primitive.Uint32(rfd.inode.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) + return 0, err + case linux.TIOCGWINSZ: + return 0, rfd.inode.t.ld.windowSize(t, args) + case linux.TIOCSWINSZ: + return 0, rfd.inode.t.ld.setWindowSize(t, args) + case linux.TIOCSCTTY: + // Make the given terminal the controlling terminal of the + // calling process. + return 0, rfd.inode.t.setControllingTTY(ctx, args, false /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, rfd.inode.t.releaseControllingTTY(ctx, args, false /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return rfd.inode.t.foregroundProcessGroup(ctx, args, false /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return rfd.inode.t.setForegroundProcessGroup(ctx, args, false /* isMaster */) + default: + maybeEmitUnimplementedEvent(ctx, cmd) + return 0, syserror.ENOTTY + } +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (rfd *replicaFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + creds := auth.CredentialsFromContext(ctx) + fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem() + return rfd.inode.SetStat(ctx, fs, creds, opts) +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (rfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := rfd.vfsfd.VirtualDentry().Mount().Filesystem() + return rfd.inode.Stat(ctx, fs, opts) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (rfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return rfd.Locks().LockPOSIX(ctx, &rfd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (rfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return rfd.Locks().UnlockPOSIX(ctx, &rfd.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go deleted file mode 100644 index a9da7af64..000000000 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devpts - -import ( - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" - fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" - "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" -) - -// slaveInode is the inode for the slave end of the Terminal. -type slaveInode struct { - implStatFS - kernfs.InodeAttrs - kernfs.InodeNoopRefCount - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - - locks vfs.FileLocks - - // Keep a reference to this inode's dentry. - dentry kernfs.Dentry - - // root is the devpts root inode. - root *rootInode - - // t is the connected Terminal. - t *Terminal -} - -var _ kernfs.Inode = (*slaveInode)(nil) - -// Open implements kernfs.Inode.Open. -func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - si.IncRef() - fd := &slaveFileDescription{ - inode: si, - } - fd.LockFD.Init(&si.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { - si.DecRef(ctx) - return nil, err - } - return &fd.vfsfd, nil - -} - -// Valid implements kernfs.Inode.Valid. -func (si *slaveInode) Valid(context.Context) bool { - // Return valid if the slave still exists. - si.root.mu.Lock() - defer si.root.mu.Unlock() - _, ok := si.root.slaves[si.t.n] - return ok -} - -// Stat implements kernfs.Inode.Stat. -func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts) - if err != nil { - return linux.Statx{}, err - } - statx.Blksize = 1024 - statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR - statx.RdevMinor = si.t.n - return statx, nil -} - -// SetStat implements kernfs.Inode.SetStat -func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - if opts.Stat.Mask&linux.STATX_SIZE != 0 { - return syserror.EINVAL - } - return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) -} - -type slaveFileDescription struct { - vfsfd vfs.FileDescription - vfs.FileDescriptionDefaultImpl - vfs.LockFD - - inode *slaveInode -} - -var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil) - -// Release implements fs.FileOperations.Release. -func (sfd *slaveFileDescription) Release(ctx context.Context) { - sfd.inode.DecRef(ctx) -} - -// EventRegister implements waiter.Waitable.EventRegister. -func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask) -} - -// EventUnregister implements waiter.Waitable.EventUnregister. -func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) { - sfd.inode.t.ld.slaveWaiter.EventUnregister(e) -} - -// Readiness implements waiter.Waitable.Readiness. -func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { - return sfd.inode.t.ld.slaveReadiness() -} - -// Read implements vfs.FileDescriptionImpl.Read. -func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { - return sfd.inode.t.ld.inputQueueRead(ctx, dst) -} - -// Write implements vfs.FileDescriptionImpl.Write. -func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { - return sfd.inode.t.ld.outputQueueWrite(ctx, src) -} - -// Ioctl implements vfs.FileDescriptionImpl.Ioctl. -func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - switch cmd := args[1].Uint(); cmd { - case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ - // Get the number of bytes in the input queue read buffer. - return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args) - case linux.TCGETS: - return sfd.inode.t.ld.getTermios(ctx, io, args) - case linux.TCSETS: - return sfd.inode.t.ld.setTermios(ctx, io, args) - case linux.TCSETSW: - // TODO(b/29356795): This should drain the output queue first. - return sfd.inode.t.ld.setTermios(ctx, io, args) - case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err - case linux.TIOCGWINSZ: - return 0, sfd.inode.t.ld.windowSize(ctx, io, args) - case linux.TIOCSWINSZ: - return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args) - case linux.TIOCSCTTY: - // Make the given terminal the controlling terminal of the - // calling process. - return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */) - case linux.TIOCNOTTY: - // Release this process's controlling terminal. - return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) - case linux.TIOCGPGRP: - // Get the foreground process group. - return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) - case linux.TIOCSPGRP: - // Set the foreground process group. - return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) - default: - maybeEmitUnimplementedEvent(ctx, cmd) - return 0, syserror.ENOTTY - } -} - -// SetStat implements vfs.FileDescriptionImpl.SetStat. -func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - creds := auth.CredentialsFromContext(ctx) - fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.SetStat(ctx, fs, creds, opts) -} - -// Stat implements vfs.FileDescriptionImpl.Stat. -func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.Stat(ctx, fs, opts) -} - -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block) -} - -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence) -} diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go index 7d2781c54..510bd6d89 100644 --- a/pkg/sentry/fsimpl/devpts/terminal.go +++ b/pkg/sentry/fsimpl/devpts/terminal.go @@ -17,9 +17,9 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" ) // Terminal is a pseudoterminal. @@ -36,25 +36,25 @@ type Terminal struct { // this terminal. This field is immutable. masterKTTY *kernel.TTY - // slaveKTTY contains the controlling process of the slave end of this + // replicaKTTY contains the controlling process of the replica end of this // terminal. This field is immutable. - slaveKTTY *kernel.TTY + replicaKTTY *kernel.TTY } func newTerminal(n uint32) *Terminal { - termios := linux.DefaultSlaveTermios + termios := linux.DefaultReplicaTermios t := Terminal{ - n: n, - ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{Index: n}, - slaveKTTY: &kernel.TTY{Index: n}, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + replicaKTTY: &kernel.TTY{Index: n}, } return &t } // setControllingTTY makes tm the controlling terminal of the calling thread // group. -func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("setControllingTTY must be called from a task context") @@ -65,7 +65,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a // releaseControllingTTY removes tm as the controlling terminal of the calling // thread group. -func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("releaseControllingTTY must be called from a task context") @@ -75,7 +75,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar } // foregroundProcessGroup gets the process group ID of tm's foreground process. -func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("foregroundProcessGroup must be called from a task context") @@ -87,24 +87,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a } // Write it out to *arg. - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ - AddressSpaceActive: true, - }) + retP := primitive.Int32(ret) + _, err = retP.CopyOut(task, args[2].Pointer()) return 0, err } // foregroundProcessGroup sets tm's foreground process. -func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("setForegroundProcessGroup must be called from a task context") } // Read in the process group ID. - var pgid int32 - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgid primitive.Int32 + if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } @@ -116,5 +113,5 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { if isMaster { return tm.masterKTTY } - return tm.slaveKTTY + return tm.replicaKTTY } diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index 52f44f66d..6d1753080 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -33,8 +33,10 @@ import ( const Name = "devtmpfs" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct { - initOnce sync.Once + initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1664): not yet supported. initErr error // fs is the tmpfs filesystem that backs all mounts of this FilesystemType. @@ -80,7 +82,7 @@ type Accessor struct { // NewAccessor returns an Accessor that supports creation of device special // files in the devtmpfs instance registered with name fsTypeName in vfsObj. func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) { - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.MountOptions{}) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go index 827a608cb..3a38b8bb4 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go @@ -48,7 +48,7 @@ func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.Virtu }) // Create a test mount namespace with devtmpfs mounted at "/dev". - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.MountOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 812171fa3..1c27ad700 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -30,9 +30,11 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// EventFileDescription implements FileDescriptionImpl for file-based event +// EventFileDescription implements vfs.FileDescriptionImpl for file-based event // notification (eventfd). Eventfds are usually internal to the Sentry but in // certain situations they may be converted into a host-backed eventfd. +// +// +stateify savable type EventFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -106,7 +108,7 @@ func (efd *EventFileDescription) HostFD() (int, error) { return efd.hostfd, nil } -// Release implements FileDescriptionImpl.Release() +// Release implements vfs.FileDescriptionImpl.Release. func (efd *EventFileDescription) Release(context.Context) { efd.mu.Lock() defer efd.mu.Unlock() @@ -119,7 +121,7 @@ func (efd *EventFileDescription) Release(context.Context) { } } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { if dst.NumBytes() < 8 { return 0, syscall.EINVAL @@ -130,7 +132,7 @@ func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequenc return 8, nil } -// Write implements FileDescriptionImpl.Write. +// Write implements vfs.FileDescriptionImpl.Write. func (efd *EventFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { if src.NumBytes() < 8 { return 0, syscall.EINVAL diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index abc610ef3..7b1eec3da 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/fd", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", @@ -86,9 +88,9 @@ go_test( library = ":ext", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fspath", + "//pkg/marshal/primitive", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index a2cc9b59f..c349b886e 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -59,7 +59,11 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }) if err != nil { f.Close() return nil, nil, nil, nil, err diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index 8bb104ff0..1165234f9 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -18,7 +18,7 @@ import ( "io" "math" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/syserror" ) @@ -34,19 +34,19 @@ type blockMapFile struct { // directBlks are the direct blocks numbers. The physical blocks pointed by // these holds file data. Contains file blocks 0 to 11. - directBlks [numDirectBlks]uint32 + directBlks [numDirectBlks]primitive.Uint32 // indirectBlk is the physical block which contains (blkSize/4) direct block // numbers (as uint32 integers). - indirectBlk uint32 + indirectBlk primitive.Uint32 // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect // block numbers (as uint32 integers). - doubleIndirectBlk uint32 + doubleIndirectBlk primitive.Uint32 // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly // indirect block numbers (as uint32 integers). - tripleIndirectBlk uint32 + tripleIndirectBlk primitive.Uint32 // coverage at (i)th index indicates the amount of file data a node at // height (i) covers. Height 0 is the direct block. @@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) { } blkMap := file.regFile.inode.diskInode.Data() - binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks) - binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk) + for i := 0; i < numDirectBlks; i++ { + file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4]) + } + file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4]) + file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4]) + file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4]) return file, nil } @@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { switch { case offset < dirBlksEnd: // Direct block. - curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:]) + curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:]) case offset < indirBlkEnd: // Indirect block. - curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:]) + curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:]) case offset < doubIndirBlkEnd: // Doubly indirect block. - curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:]) + curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:]) default: // Triply indirect block. - curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:]) + curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:]) } read += curR @@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds read := 0 curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { - var childPhyBlk uint32 + var childPhyBlk primitive.Uint32 err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } - n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:]) + n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:]) read += n if err != nil { return read, err diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 6fa84e7aa..ed98b482e 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -20,7 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) var fileData []byte blkNums := newBlkNumGen() - var data []byte + off := 0 + data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes()) // Write the direct blocks. for i := 0; i < numDirectBlks; i++ { - curBlkNum := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, curBlkNum) - fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(data[off:]) + off += curBlkNum.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...) } // Write to indirect block. - indirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, indirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...) - - // Write to indirect block. - doublyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...) - - // Write to indirect block. - triplyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...) + indirectBlk := primitive.Uint32(blkNums.next()) + indirectBlk.MarshalBytes(data[off:]) + off += indirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...) + + // Write to double indirect block. + doublyIndirectBlk := primitive.Uint32(blkNums.next()) + doublyIndirectBlk.MarshalBytes(data[off:]) + off += doublyIndirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...) + + // Write to triple indirect block. + triplyIndirectBlk := primitive.Uint32(blkNums.next()) + triplyIndirectBlk.MarshalBytes(data[off:]) + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...) args := inodeArgs{ fs: &filesystem{ @@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN var fileData []byte for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 { - curBlkNum := blkNums.next() - copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum)) - fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(disk[off : off+4]) + fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...) } return fileData } diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 7a1b4219f..9bfed883a 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -20,6 +20,8 @@ import ( ) // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 0fc01668d..0ad79b381 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -16,7 +16,6 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -28,6 +27,8 @@ import ( ) // directory represents a directory inode. It holds the childList in memory. +// +// +stateify savable type directory struct { inode inode @@ -39,7 +40,7 @@ type directory struct { // Lock Order (outermost locks must be taken first): // directory.mu // filesystem.mu - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // childList is a list containing (1) child dirents and (2) fake dirents // (with diskDirent == nil) that represent the iteration position of @@ -98,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) { } else { curDirent.diskDirent = &disklayout.DirentOld{} } - binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent) + curDirent.diskDirent.UnmarshalBytes(buf) if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 { // Inode number and name length fields being set to 0 is used to indicate @@ -120,6 +121,8 @@ func (i *inode) isDir() bool { } // dirent is the directory.childList node. +// +// +stateify savable type dirent struct { diskDirent disklayout.Dirent @@ -129,6 +132,8 @@ type dirent struct { // directoryFD represents a directory file description. It implements // vfs.FileDescriptionImpl. +// +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 9bd9c76c0..d98a05dd8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -22,10 +22,11 @@ go_library( "superblock_old.go", "test_utils.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/marshal", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go index ad6f4fef8..0d56ae9da 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // BlockGroup represents a Linux ext block group descriptor. An ext file system // is split into a series of block groups. This provides an access layer to // information needed to access and use a block group. @@ -30,6 +34,8 @@ package disklayout // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors. type BlockGroup interface { + marshal.Marshallable + // InodeTable returns the absolute block number of the block containing the // inode table. This points to an array of Inode structs. Inode tables are // statically allocated at mkfs time. The superblock records the number of diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go index 3e16c76db..a35fa22a0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go @@ -17,6 +17,8 @@ package disklayout // BlockGroup32Bit emulates the first half of struct ext4_group_desc in // fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and // 32-bit ext4 filesystems. It implements BlockGroup interface. +// +// +marshal type BlockGroup32Bit struct { BlockBitmapLo uint32 InodeBitmapLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go index 9a809197a..d54d1d345 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go @@ -18,6 +18,8 @@ package disklayout // It is the block group descriptor struct for 64-bit ext4 filesystems. // It implements BlockGroup interface. It is an extension of the 32-bit // version of BlockGroup. +// +// +marshal type BlockGroup64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go index 0ef4294c0..e4ce484e4 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go @@ -21,6 +21,8 @@ import ( // TestBlockGroupSize tests that the block group descriptor structs are of the // correct size. func TestBlockGroupSize(t *testing.T) { - assertSize(t, BlockGroup32Bit{}, 32) - assertSize(t, BlockGroup64Bit{}, 64) + var bgSmall BlockGroup32Bit + assertSize(t, &bgSmall, 32) + var bgBig BlockGroup64Bit + assertSize(t, &bgBig, 64) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go index 417b6cf65..568c8cb4c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go @@ -15,6 +15,7 @@ package disklayout import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fs" ) @@ -51,6 +52,8 @@ var ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories. type Dirent interface { + marshal.Marshallable + // Inode returns the absolute inode number of the underlying inode. // Inode number 0 signifies an unused dirent. Inode() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go index 29ae4a5c2..51f9c2946 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go @@ -29,12 +29,14 @@ import ( // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentNew struct { InodeNumber uint32 RecordLength uint16 NameLength uint8 FileTypeRaw uint8 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentNew implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go index 6fff12a6e..d4b19e086 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go @@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs" // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentOld struct { InodeNumber uint32 RecordLength uint16 NameLength uint16 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentOld implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go index 934919f8a..3486864dc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go @@ -21,6 +21,8 @@ import ( // TestDirentSize tests that the dirent structs are of the correct // size. func TestDirentSize(t *testing.T) { - assertSize(t, DirentOld{}, uintptr(DirentSize)) - assertSize(t, DirentNew{}, uintptr(DirentSize)) + var dOld DirentOld + assertSize(t, &dOld, DirentSize) + var dNew DirentNew + assertSize(t, &dNew, DirentSize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go index bdf4e2132..0834e9ba8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go +++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go @@ -36,8 +36,6 @@ // escape analysis on an unknown implementation at compile time. // // Notes: -// - All fields in these structs are exported because binary.Read would -// panic otherwise. // - All structures on disk are in little-endian order. Only jbd2 (journal) // structures are in big-endian order. // - All OS dependent fields in these structures will be interpretted using diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go index 4110649ab..b13999bfc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // Extents were introduced in ext4 and provide huge performance gains in terms // data locality and reduced metadata block usage. Extents are organized in // extent trees. The root node is contained in inode.BlocksRaw. @@ -64,6 +68,8 @@ type ExtentNode struct { // ExtentEntry represents an extent tree node entry. The entry can either be // an ExtentIdx or Extent itself. This exists to simplify navigation logic. type ExtentEntry interface { + marshal.Marshallable + // FileBlock returns the first file block number covered by this entry. FileBlock() uint32 @@ -75,6 +81,8 @@ type ExtentEntry interface { // tree node begins with this and is followed by `NumEntries` number of: // - Extent if `Depth` == 0 // - ExtentIdx otherwise +// +// +marshal type ExtentHeader struct { // Magic in the extent magic number, must be 0xf30a. Magic uint16 @@ -96,6 +104,8 @@ type ExtentHeader struct { // internal nodes. Sorted in ascending order based on FirstFileBlock since // Linux does a binary search on this. This points to a block containing the // child node. +// +// +marshal type ExtentIdx struct { FirstFileBlock uint32 ChildBlockLo uint32 @@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 { // nodes. Sorted in ascending order based on FirstFileBlock since Linux does a // binary search on this. This points to an array of data blocks containing the // file data. It covers `Length` data blocks starting from `StartBlock`. +// +// +marshal type Extent struct { FirstFileBlock uint32 Length uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go index 8762b90db..c96002e19 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go @@ -21,7 +21,10 @@ import ( // TestExtentSize tests that the extent structs are of the correct // size. func TestExtentSize(t *testing.T) { - assertSize(t, ExtentHeader{}, ExtentHeaderSize) - assertSize(t, ExtentIdx{}, ExtentEntrySize) - assertSize(t, Extent{}, ExtentEntrySize) + var h ExtentHeader + assertSize(t, &h, ExtentHeaderSize) + var i ExtentIdx + assertSize(t, &i, ExtentEntrySize) + var e Extent + assertSize(t, &e, ExtentEntrySize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go index 88ae913f5..ef25040a9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go @@ -16,6 +16,7 @@ package disklayout import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) @@ -38,6 +39,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes. type Inode interface { + marshal.Marshallable + // Mode returns the linux file mode which is majorly used to extract // information like: // - File permissions (read/write/execute by user/group/others). diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go index 8f9f574ce..a4503f5cf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go @@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time" // are used to provide nanoscond precision. Hence, these timestamps will now // overflow in May 2446. // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps. +// +// +marshal type InodeNew struct { InodeOld diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go index db25b11b6..e6b28babf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go @@ -30,6 +30,8 @@ const ( // // All fields representing time are in seconds since the epoch. Which means that // they will overflow in January 2038. +// +// +marshal type InodeOld struct { ModeRaw uint16 UIDLo uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go index dd03ee50e..90744e956 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go @@ -24,10 +24,12 @@ import ( // TestInodeSize tests that the inode structs are of the correct size. func TestInodeSize(t *testing.T) { - assertSize(t, InodeOld{}, OldInodeSize) + var iOld InodeOld + assertSize(t, &iOld, OldInodeSize) // This was updated from 156 bytes to 160 bytes in Oct 2015. - assertSize(t, InodeNew{}, 160) + var iNew InodeNew + assertSize(t, &iNew, 160) } // TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go index 8bb327006..70948ebe9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + const ( // SbOffset is the absolute offset at which the superblock is placed. SbOffset = 1024 @@ -38,6 +42,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block. type SuperBlock interface { + marshal.Marshallable + // InodesCount returns the total number of inodes in this filesystem. InodesCount() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go index 53e515fd3..4dc6080fb 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go @@ -17,6 +17,8 @@ package disklayout // SuperBlock32Bit implements SuperBlock and represents the 32-bit version of // the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if // RevLevel = DynamicRev and 64-bit feature is disabled. +// +// +marshal type SuperBlock32Bit struct { // We embed the old superblock struct here because the 32-bit version is just // an extension of the old version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go index 7c1053fb4..2c9039327 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go @@ -19,6 +19,8 @@ package disklayout // 1024 bytes (smallest possible block size) and hence the superblock always // fits in no more than one data block. Should only be used when the 64-bit // feature is set. +// +// +marshal type SuperBlock64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go index 9221e0251..e4709f23c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go @@ -16,6 +16,8 @@ package disklayout // SuperBlockOld implements SuperBlock and represents the old version of the // superblock struct. Should be used only if RevLevel = OldRev. +// +// +marshal type SuperBlockOld struct { InodesCountRaw uint32 BlocksCountLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go index 463b5ba21..b734b6987 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go @@ -21,7 +21,10 @@ import ( // TestSuperBlockSize tests that the superblock structs are of the correct // size. func TestSuperBlockSize(t *testing.T) { - assertSize(t, SuperBlockOld{}, 84) - assertSize(t, SuperBlock32Bit{}, 336) - assertSize(t, SuperBlock64Bit{}, 1024) + var sbOld SuperBlockOld + assertSize(t, &sbOld, 84) + var sb32 SuperBlock32Bit + assertSize(t, &sb32, 336) + var sb64 SuperBlock64Bit + assertSize(t, &sb64, 1024) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go index 9c63f04c0..a4bc08411 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go +++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go @@ -18,13 +18,13 @@ import ( "reflect" "testing" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" ) -func assertSize(t *testing.T, v interface{}, want uintptr) { +func assertSize(t *testing.T, v marshal.Marshallable, want int) { t.Helper() - if got := binary.Size(v); got != want { + if got := v.SizeBytes(); got != want { t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got) } } diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go index 08ffc2834..aca258d40 100644 --- a/pkg/sentry/fsimpl/ext/ext.go +++ b/pkg/sentry/fsimpl/ext/ext.go @@ -34,6 +34,8 @@ import ( const Name = "ext" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // Compiles only if FilesystemType implements vfs.FilesystemType. diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 2dbaee287..0989558cd 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -71,7 +71,11 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }) if err != nil { f.Close() return nil, nil, nil, nil, err diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index c36225a7c..778460107 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -18,12 +18,13 @@ import ( "io" "sort" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) // extentFile is a type of regular file which uses extents to store file data. +// +// +stateify savable type extentFile struct { regFile regularFile @@ -58,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) { func (f *extentFile) buildExtTree() error { rootNodeData := f.regFile.inode.diskInode.Data() - binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header) + f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize]) // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. if f.root.Header.NumEntries > 4 { @@ -77,7 +78,7 @@ func (f *extentFile) buildExtTree() error { // Internal node. curEntry = &disklayout.ExtentIdx{} } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry) + curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize]) f.root.Entries[i].Entry = curEntry } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index cd10d46ee..985f76ac0 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] // writeTree writes the tree represented by `root` to the inode and disk. It // also writes random file data on disk. func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte { - rootData := binary.Marshal(nil, binary.LittleEndian, root.Header) + rootData := in.diskInode.Data() + root.Header.MarshalBytes(rootData) + off := root.Header.SizeBytes() for _, ep := range root.Entries { - rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(rootData[off:]) + off += ep.Entry.SizeBytes() } - copy(in.diskInode.Data(), rootData) - var fileData []byte for _, ep := range root.Entries { if root.Header.Height == 0 { @@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl // writeTreeToDisk is the recursive step for writeTree which writes the tree // on the disk only. Also writes random file data on disk. func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte { - nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header) + nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:] + curNode.Node.Header.MarshalBytes(nodeData) + off := curNode.Node.Header.SizeBytes() for _, ep := range curNode.Node.Entries { - nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(nodeData[off:]) + off += ep.Entry.SizeBytes() } - copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData) - var fileData []byte for _, ep := range curNode.Node.Entries { if curNode.Node.Header.Height == 0 { diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 8565d1a66..917f1873d 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -38,11 +38,13 @@ var ( ) // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem // mu serializes changes to the Dentry tree. - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` // dev represents the underlying fs device. It does not require protection // because io.ReaderAt permits concurrent read calls to it. It translates to @@ -490,7 +492,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { _, inode, err := fs.walk(ctx, rp, false) if err != nil { @@ -504,8 +506,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { _, _, err := fs.walk(ctx, rp, false) if err != nil { return nil, err @@ -513,8 +515,8 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, syserror.ENOTSUP } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { _, _, err := fs.walk(ctx, rp, false) if err != nil { return "", err @@ -522,8 +524,8 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", syserror.ENOTSUP } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { _, _, err := fs.walk(ctx, rp, false) if err != nil { return err @@ -531,8 +533,8 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return syserror.ENOTSUP } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { _, _, err := fs.walk(ctx, rp, false) if err != nil { return err diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 30636cf66..9009ba3c7 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -37,6 +37,8 @@ import ( // |-- regular-- // |-- extent file // |-- block map file +// +// +stateify savable type inode struct { // refs is a reference count. refs is accessed using atomic memory operations. refs int64 diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index e73e740d6..4a5539b37 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -31,6 +31,8 @@ import ( // regularFile represents a regular file's inode. This too follows the // inheritance pattern prevelant in the vfs layer described in // pkg/sentry/vfs/README.md. +// +// +stateify savable type regularFile struct { inode inode @@ -67,6 +69,8 @@ func (in *inode) isRegular() bool { // directoryFD represents a directory file description. It implements // vfs.FileDescriptionImpl. +// +// +stateify savable type regularFileFD struct { fileDescription vfs.LockFD @@ -75,7 +79,7 @@ type regularFileFD struct { off int64 // offMu serializes operations that may mutate off. - offMu sync.Mutex + offMu sync.Mutex `state:"nosave"` } // Release implements vfs.FileDescriptionImpl.Release. diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index 2fd0d1fa8..5e2bcc837 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -23,6 +23,8 @@ import ( ) // symlink represents a symlink inode. +// +// +stateify savable type symlink struct { inode inode target string // immutable @@ -61,9 +63,11 @@ func (in *inode) isSymlink() bool { return ok } -// symlinkFD represents a symlink file description and implements implements +// symlinkFD represents a symlink file description and implements // vfs.FileDescriptionImpl. which may only be used if open options contains // O_PATH. For this reason most of the functions return EBADF. +// +// +stateify savable type symlinkFD struct { fileDescription vfs.NoLockFD diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go index d8b728f8c..58ef7b9b8 100644 --- a/pkg/sentry/fsimpl/ext/utils.go +++ b/pkg/sentry/fsimpl/ext/utils.go @@ -17,21 +17,21 @@ package ext import ( "io" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) // readFromDisk performs a binary read from disk into the given struct from // the absolute offset provided. -func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error { - n := binary.Size(v) +func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error { + n := v.SizeBytes() buf := make([]byte, n) if read, _ := dev.ReadAt(buf, abOff); read < int(n) { return syserror.EIO } - binary.Unmarshal(buf, binary.LittleEndian, v) + v.UnmarshalBytes(buf) return nil } diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 53a4f3012..045d7ab08 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -30,19 +30,26 @@ go_library( name = "fuse", srcs = [ "connection.go", + "connection_control.go", "dev.go", + "directory.go", + "file.go", "fusefs.go", - "init.go", "inode_refs.go", + "read_write.go", "register.go", + "regular_file.go", "request_list.go", + "request_response.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", "//pkg/log", + "//pkg/marshal", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", @@ -52,7 +59,6 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -60,10 +66,15 @@ go_library( go_test( name = "fuse_test", size = "small", - srcs = ["dev_test.go"], + srcs = [ + "connection_test.go", + "dev_test.go", + "utils_test.go", + ], library = ":fuse", deps = [ "//pkg/abi/linux", + "//pkg/marshal", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -71,6 +82,5 @@ go_test( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go index 6df2728ab..8ccda1264 100644 --- a/pkg/sentry/fsimpl/fuse/connection.go +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -15,31 +15,17 @@ package fuse import ( - "errors" - "fmt" "sync" - "sync/atomic" - "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) -// maxActiveRequestsDefault is the default setting controlling the upper bound -// on the number of active requests at any given time. -const maxActiveRequestsDefault = 10000 - -// Ordinary requests have even IDs, while interrupts IDs are odd. -// Used to increment the unique ID for each FUSE request. -var reqIDStep uint64 = 2 - const ( // fuseDefaultMaxBackground is the default value for MaxBackground. fuseDefaultMaxBackground = 12 @@ -52,43 +38,39 @@ const ( fuseDefaultMaxPagesPerReq = 32 ) -// Request represents a FUSE operation request that hasn't been sent to the -// server yet. +// connection is the struct by which the sentry communicates with the FUSE server daemon. // -// +stateify savable -type Request struct { - requestEntry - - id linux.FUSEOpID - hdr *linux.FUSEHeaderIn - data []byte -} - -// Response represents an actual response from the server, including the -// response payload. +// Lock order: +// - conn.fd.mu +// - conn.mu +// - conn.asyncMu // // +stateify savable -type Response struct { - opcode linux.FUSEOpcode - hdr linux.FUSEHeaderOut - data []byte -} - -// connection is the struct by which the sentry communicates with the FUSE server daemon. type connection struct { fd *DeviceFD + // mu protects access to struct memebers. + mu sync.Mutex `state:"nosave"` + + // attributeVersion is the version of connection's attributes. + attributeVersion uint64 + + // We target FUSE 7.23. // The following FUSE_INIT flags are currently unsupported by this implementation: - // - FUSE_ATOMIC_O_TRUNC: requires open(..., O_TRUNC) // - FUSE_EXPORT_SUPPORT - // - FUSE_HANDLE_KILLPRIV // - FUSE_POSIX_LOCKS: requires POSIX locks // - FUSE_FLOCK_LOCKS: requires POSIX locks // - FUSE_AUTO_INVAL_DATA: requires page caching eviction - // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction // - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation // - FUSE_ASYNC_DIO - // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler + // - FUSE_PARALLEL_DIROPS (7.25) + // - FUSE_HANDLE_KILLPRIV (7.26) + // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler (7.26) + // - FUSE_ABORT_ERROR (7.27) + // - FUSE_CACHE_SYMLINKS (7.28) + // - FUSE_NO_OPENDIR_SUPPORT (7.29) + // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction (7.30) + // - FUSE_MAP_ALIGNMENT (7.31) // initialized after receiving FUSE_INIT reply. // Until it's set, suspend sending FUSE requests. @@ -96,11 +78,7 @@ type connection struct { initialized int32 // initializedChan is used to block requests before initialization. - initializedChan chan struct{} - - // blocked when there are too many outstading backgrounds requests (NumBackground == MaxBackground). - // TODO(gvisor.dev/issue/3185): update the numBackground accordingly; use a channel to block. - blocked bool + initializedChan chan struct{} `state:".(bool)"` // connected (connection established) when a new FUSE file system is created. // Set to false when: @@ -109,48 +87,55 @@ type connection struct { // device release. connected bool - // aborted via sysfs. - // TODO(gvisor.dev/issue/3185): abort all queued requests. - aborted bool - // connInitError if FUSE_INIT encountered error (major version mismatch). // Only set in INIT. connInitError bool // connInitSuccess if FUSE_INIT is successful. // Only set in INIT. - // Used for destory. + // Used for destory (not yet implemented). connInitSuccess bool - // TODO(gvisor.dev/issue/3185): All the queue logic are working in progress. - - // NumberBackground is the number of requests in the background. - numBackground uint16 + // aborted via sysfs, and will send ECONNABORTED to read after disconnection (instead of ENODEV). + // Set only if abortErr is true and via fuse control fs (not yet implemented). + // TODO(gvisor.dev/issue/3525): set this to true when user aborts. + aborted bool - // congestionThreshold for NumBackground. - // Negotiated in FUSE_INIT. - congestionThreshold uint16 + // numWating is the number of requests waiting to be + // sent to FUSE device or being processed by FUSE daemon. + numWaiting uint32 - // maxBackground is the maximum number of NumBackground. - // Block connection when it is reached. - // Negotiated in FUSE_INIT. - maxBackground uint16 + // Terminology note: + // + // - `asyncNumMax` is the `MaxBackground` in the FUSE_INIT_IN struct. + // + // - `asyncCongestionThreshold` is the `CongestionThreshold` in the FUSE_INIT_IN struct. + // + // We call the "background" requests in unix term as async requests. + // The "async requests" in unix term is our async requests that expect a reply, + // i.e. `!request.noReply` - // numActiveBackground is the number of requests in background and has being marked as active. - numActiveBackground uint16 + // asyncMu protects the async request fields. + asyncMu sync.Mutex `state:"nosave"` - // numWating is the number of requests waiting for completion. - numWaiting uint32 + // asyncNum is the number of async requests. + // Protected by asyncMu. + asyncNum uint16 - // TODO(gvisor.dev/issue/3185): BgQueue - // some queue for background queued requests. + // asyncCongestionThreshold the number of async requests. + // Negotiated in FUSE_INIT as "CongestionThreshold". + // TODO(gvisor.dev/issue/3529): add congestion control. + // Protected by asyncMu. + asyncCongestionThreshold uint16 - // bgLock protects: - // MaxBackground, CongestionThreshold, NumBackground, - // NumActiveBackground, BgQueue, Blocked. - bgLock sync.Mutex + // asyncNumMax is the maximum number of asyncNum. + // Connection blocks the async requests when it is reached. + // Negotiated in FUSE_INIT as "MaxBackground". + // Protected by asyncMu. + asyncNumMax uint16 // maxRead is the maximum size of a read buffer in in bytes. + // Initialized from a fuse fs parameter. maxRead uint32 // maxWrite is the maximum size of a write buffer in bytes. @@ -165,23 +150,20 @@ type connection struct { // Negotiated and only set in INIT. minor uint32 - // asyncRead if read pages asynchronously. + // atomicOTrunc is true when FUSE does not send a separate SETATTR request + // before open with O_TRUNC flag. // Negotiated and only set in INIT. - asyncRead bool + atomicOTrunc bool - // abortErr is true if kernel need to return an unique read error after abort. + // asyncRead if read pages asynchronously. // Negotiated and only set in INIT. - abortErr bool + asyncRead bool // writebackCache is true for write-back cache policy, // false for write-through policy. // Negotiated and only set in INIT. writebackCache bool - // cacheSymlinks if filesystem needs to cache READLINK responses in page cache. - // Negotiated and only set in INIT. - cacheSymlinks bool - // bigWrites if doing multi-page cached writes. // Negotiated and only set in INIT. bigWrites bool @@ -189,116 +171,86 @@ type connection struct { // dontMask if filestestem does not apply umask to creation modes. // Negotiated in INIT. dontMask bool + + // noOpen if FUSE server doesn't support open operation. + // This flag only influence performance, not correctness of the program. + noOpen bool +} + +func (conn *connection) saveInitializedChan() bool { + select { + case <-conn.initializedChan: + return true // Closed. + default: + return false // Not closed. + } +} + +func (conn *connection) loadInitializedChan(closed bool) { + conn.initializedChan = make(chan struct{}, 1) + if closed { + close(conn.initializedChan) + } } // newFUSEConnection creates a FUSE connection to fd. -func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*connection, error) { +func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, opts *filesystemOptions) (*connection, error) { // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to // mount a FUSE filesystem. fuseFD := fd.Impl().(*DeviceFD) - fuseFD.mounted = true // Create the writeBuf for the header to be stored in. hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) fuseFD.writeBuf = make([]byte, hdrLen) fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse) - fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests) + fuseFD.fullQueueCh = make(chan struct{}, opts.maxActiveRequests) fuseFD.writeCursor = 0 return &connection{ - fd: fuseFD, - maxBackground: fuseDefaultMaxBackground, - congestionThreshold: fuseDefaultCongestionThreshold, - maxPages: fuseDefaultMaxPagesPerReq, - initializedChan: make(chan struct{}), - connected: true, - }, nil -} - -// SetInitialized atomically sets the connection as initialized. -func (conn *connection) SetInitialized() { - // Unblock the requests sent before INIT. - close(conn.initializedChan) - - // Close the channel first to avoid the non-atomic situation - // where conn.initialized is true but there are - // tasks being blocked on the channel. - // And it prevents the newer tasks from gaining - // unnecessary higher chance to be issued before the blocked one. - - atomic.StoreInt32(&(conn.initialized), int32(1)) -} - -// IsInitialized atomically check if the connection is initialized. -// pairs with SetInitialized(). -func (conn *connection) Initialized() bool { - return atomic.LoadInt32(&(conn.initialized)) != 0 -} - -// NewRequest creates a new request that can be sent to the FUSE server. -func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { - conn.fd.mu.Lock() - defer conn.fd.mu.Unlock() - conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) - - hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() - hdr := linux.FUSEHeaderIn{ - Len: uint32(hdrLen + payload.SizeBytes()), - Opcode: opcode, - Unique: conn.fd.nextOpID, - NodeID: ino, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - PID: pid, - } - - buf := make([]byte, hdr.Len) - hdr.MarshalUnsafe(buf[:hdrLen]) - payload.MarshalUnsafe(buf[hdrLen:]) - - return &Request{ - id: hdr.Unique, - hdr: &hdr, - data: buf, + fd: fuseFD, + asyncNumMax: fuseDefaultMaxBackground, + asyncCongestionThreshold: fuseDefaultCongestionThreshold, + maxRead: opts.maxRead, + maxPages: fuseDefaultMaxPagesPerReq, + initializedChan: make(chan struct{}), + connected: true, }, nil } -// Call makes a request to the server and blocks the invoking task until a -// server responds with a response. Task should never be nil. -// Requests will not be sent before the connection is initialized. -// For async tasks, use CallAsync(). -func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) { - // Block requests sent before connection is initalized. - if !conn.Initialized() { - if err := t.Block(conn.initializedChan); err != nil { - return nil, err - } - } - - return conn.call(t, r) +// CallAsync makes an async (aka background) request. +// It's a simple wrapper around Call(). +func (conn *connection) CallAsync(t *kernel.Task, r *Request) error { + r.async = true + _, err := conn.Call(t, r) + return err } -// CallAsync makes an async (aka background) request. -// Those requests either do not expect a response (e.g. release) or -// the response should be handled by others (e.g. init). -// Return immediately unless the connection is blocked (before initialization). -// Async call example: init, release, forget, aio, interrupt. +// Call makes a request to the server. +// Block before the connection is initialized. // When the Request is FUSE_INIT, it will not be blocked before initialization. -func (conn *connection) CallAsync(t *kernel.Task, r *Request) error { +// Task should never be nil. +// +// For a sync request, it blocks the invoking task until +// a server responds with a response. +// +// For an async request (that do not expect a response immediately), +// it returns directly unless being blocked either before initialization +// or when there are too many async requests ongoing. +// +// Example for async request: +// init, readahead, write, async read/write, fuse_notify_reply, +// non-sync release, interrupt, forget. +// +// The forget request does not have a reply, +// as documented in include/uapi/linux/fuse.h:FUSE_FORGET. +func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) { // Block requests sent before connection is initalized. if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT { if err := t.Block(conn.initializedChan); err != nil { - return err + return nil, err } } - // This should be the only place that invokes call() with a nil task. - _, err := conn.call(nil, r) - return err -} - -// call makes a call without blocking checks. -func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) { if !conn.connected { return nil, syserror.ENOTCONN } @@ -315,31 +267,6 @@ func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) { return fut.resolve(t) } -// Error returns the error of the FUSE call. -func (r *Response) Error() error { - errno := r.hdr.Error - if errno >= 0 { - return nil - } - - sysErrNo := syscall.Errno(-errno) - return error(sysErrNo) -} - -// UnmarshalPayload unmarshals the response data into m. -func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { - hdrLen := r.hdr.SizeBytes() - haveDataLen := r.hdr.Len - uint32(hdrLen) - wantDataLen := uint32(m.SizeBytes()) - - if haveDataLen < wantDataLen { - return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) - } - - m.UnmarshalUnsafe(r.data[hdrLen:]) - return nil -} - // callFuture makes a request to the server and returns a future response. // Call resolve() when the response needs to be fulfilled. func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) { @@ -358,11 +285,6 @@ func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, // if there are always too many ongoing requests all the time. The // supported maxActiveRequests setting should be really high to avoid this. for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { - if t == nil { - // Since there is no task that is waiting. We must error out. - return nil, errors.New("FUSE request queue full") - } - log.Infof("Blocking request %v from being queued. Too many active requests: %v", r.id, conn.fd.numActiveRequests) conn.fd.mu.Unlock() @@ -378,9 +300,19 @@ func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, // callFutureLocked makes a request to the server and returns a future response. func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) { + // Check connected again holding conn.mu. + conn.mu.Lock() + if !conn.connected { + conn.mu.Unlock() + // we checked connected before, + // this must be due to aborted connection. + return nil, syserror.ECONNABORTED + } + conn.mu.Unlock() + conn.fd.queue.PushBack(r) - conn.fd.numActiveRequests += 1 - fut := newFutureResponse(r.hdr.Opcode) + conn.fd.numActiveRequests++ + fut := newFutureResponse(r) conn.fd.completions[r.id] = fut // Signal the readers that there is something to read. @@ -388,50 +320,3 @@ func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureRes return fut, nil } - -// futureResponse represents an in-flight request, that may or may not have -// completed yet. Convert it to a resolved Response by calling Resolve, but note -// that this may block. -// -// +stateify savable -type futureResponse struct { - opcode linux.FUSEOpcode - ch chan struct{} - hdr *linux.FUSEHeaderOut - data []byte -} - -// newFutureResponse creates a future response to a FUSE request. -func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse { - return &futureResponse{ - opcode: opcode, - ch: make(chan struct{}), - } -} - -// resolve blocks the task until the server responds to its corresponding request, -// then returns a resolved response. -func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { - // If there is no Task associated with this request - then we don't try to resolve - // the response. Instead, the task writing the response (proxy to the server) will - // process the response on our behalf. - if t == nil { - log.Infof("fuse.Response.resolve: Not waiting on a response from server.") - return nil, nil - } - - if err := t.Block(f.ch); err != nil { - return nil, err - } - - return f.getResponse(), nil -} - -// getResponse creates a Response from the data the futureResponse has. -func (f *futureResponse) getResponse() *Response { - return &Response{ - opcode: f.opcode, - hdr: *f.hdr, - data: f.data, - } -} diff --git a/pkg/sentry/fsimpl/fuse/init.go b/pkg/sentry/fsimpl/fuse/connection_control.go index 779c2bd3f..bfde78559 100644 --- a/pkg/sentry/fsimpl/fuse/init.go +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -15,7 +15,11 @@ package fuse import ( + "sync/atomic" + "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -29,9 +33,10 @@ const ( // Follow the same behavior as unix fuse implementation. fuseMaxTimeGranNs = 1000000000 - // Minimum value for MaxWrite. + // Minimum value for MaxWrite and MaxRead. // Follow the same behavior as unix fuse implementation. fuseMinMaxWrite = 4096 + fuseMinMaxRead = 4096 // Temporary default value for max readahead, 128kb. fuseDefaultMaxReadahead = 131072 @@ -49,6 +54,26 @@ var ( MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold ) +// SetInitialized atomically sets the connection as initialized. +func (conn *connection) SetInitialized() { + // Unblock the requests sent before INIT. + close(conn.initializedChan) + + // Close the channel first to avoid the non-atomic situation + // where conn.initialized is true but there are + // tasks being blocked on the channel. + // And it prevents the newer tasks from gaining + // unnecessary higher chance to be issued before the blocked one. + + atomic.StoreInt32(&(conn.initialized), int32(1)) +} + +// IsInitialized atomically check if the connection is initialized. +// pairs with SetInitialized(). +func (conn *connection) Initialized() bool { + return atomic.LoadInt32(&(conn.initialized)) != 0 +} + // InitSend sends a FUSE_INIT request. func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { in := linux.FUSEInitIn{ @@ -70,29 +95,31 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { } // InitRecv receives a FUSE_INIT reply and process it. +// +// Preconditions: conn.asyncMu must not be held if minor verion is newer than 13. func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error { if err := res.Error(); err != nil { return err } - var out linux.FUSEInitOut - if err := res.UnmarshalPayload(&out); err != nil { + initRes := fuseInitRes{initLen: res.DataLen()} + if err := res.UnmarshalPayload(&initRes); err != nil { return err } - return conn.initProcessReply(&out, hasSysAdminCap) + return conn.initProcessReply(&initRes.initOut, hasSysAdminCap) } // Process the FUSE_INIT reply from the FUSE server. +// It tries to acquire the conn.asyncMu lock if minor version is newer than 13. func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error { + // No matter error or not, always set initialzied. + // to unblock the blocked requests. + defer conn.SetInitialized() + // No support for old major fuse versions. if out.Major != linux.FUSE_KERNEL_VERSION { conn.connInitError = true - - // Set the connection as initialized and unblock the blocked requests - // (i.e. return error for them). - conn.SetInitialized() - return nil } @@ -100,29 +127,14 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap conn.connInitSuccess = true conn.minor = out.Minor - // No support for limits before minor version 13. - if out.Minor >= 13 { - conn.bgLock.Lock() - - if out.MaxBackground > 0 { - conn.maxBackground = out.MaxBackground - - if !hasSysAdminCap && - conn.maxBackground > MaxUserBackgroundRequest { - conn.maxBackground = MaxUserBackgroundRequest - } - } - - if out.CongestionThreshold > 0 { - conn.congestionThreshold = out.CongestionThreshold - - if !hasSysAdminCap && - conn.congestionThreshold > MaxUserCongestionThreshold { - conn.congestionThreshold = MaxUserCongestionThreshold - } - } - - conn.bgLock.Unlock() + // No support for negotiating MaxWrite before minor version 5. + if out.Minor >= 5 { + conn.maxWrite = out.MaxWrite + } else { + conn.maxWrite = fuseMinMaxWrite + } + if conn.maxWrite < fuseMinMaxWrite { + conn.maxWrite = fuseMinMaxWrite } // No support for the following flags before minor version 6. @@ -131,8 +143,6 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0 conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0 conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0 - conn.cacheSymlinks = out.Flags&linux.FUSE_CACHE_SYMLINKS != 0 - conn.abortErr = out.Flags&linux.FUSE_ABORT_ERROR != 0 // TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs). @@ -148,19 +158,90 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap } } - // No support for negotiating MaxWrite before minor version 5. - if out.Minor >= 5 { - conn.maxWrite = out.MaxWrite - } else { - conn.maxWrite = fuseMinMaxWrite + // No support for limits before minor version 13. + if out.Minor >= 13 { + conn.asyncMu.Lock() + + if out.MaxBackground > 0 { + conn.asyncNumMax = out.MaxBackground + + if !hasSysAdminCap && + conn.asyncNumMax > MaxUserBackgroundRequest { + conn.asyncNumMax = MaxUserBackgroundRequest + } + } + + if out.CongestionThreshold > 0 { + conn.asyncCongestionThreshold = out.CongestionThreshold + + if !hasSysAdminCap && + conn.asyncCongestionThreshold > MaxUserCongestionThreshold { + conn.asyncCongestionThreshold = MaxUserCongestionThreshold + } + } + + conn.asyncMu.Unlock() } - if conn.maxWrite < fuseMinMaxWrite { - conn.maxWrite = fuseMinMaxWrite + + return nil +} + +// Abort this FUSE connection. +// It tries to acquire conn.fd.mu, conn.lock, conn.bgLock in order. +// All possible requests waiting or blocking will be aborted. +// +// Preconditions: conn.fd.mu is locked. +func (conn *connection) Abort(ctx context.Context) { + conn.mu.Lock() + conn.asyncMu.Lock() + + if !conn.connected { + conn.asyncMu.Unlock() + conn.mu.Unlock() + conn.fd.mu.Unlock() + return } - // Set connection as initialized and unblock the requests - // issued before init. - conn.SetInitialized() + conn.connected = false - return nil + // Empty the `fd.queue` that holds the requests + // not yet read by the FUSE daemon yet. + // These are a subset of the requests in `fuse.completion` map. + for !conn.fd.queue.Empty() { + req := conn.fd.queue.Front() + conn.fd.queue.Remove(req) + } + + var terminate []linux.FUSEOpID + + // 2. Collect the requests have not been sent to FUSE daemon, + // or have not received a reply. + for unique := range conn.fd.completions { + terminate = append(terminate, unique) + } + + // Release locks to avoid deadlock. + conn.asyncMu.Unlock() + conn.mu.Unlock() + + // 1. The requets blocked before initialization. + // Will reach call() `connected` check and return. + if !conn.Initialized() { + conn.SetInitialized() + } + + // 2. Terminate the requests collected above. + // Set ECONNABORTED error. + // sendError() will remove them from `fd.completion` map. + // Will enter the path of a normally received error. + for _, toTerminate := range terminate { + conn.fd.sendError(ctx, -int32(syscall.ECONNABORTED), toTerminate) + } + + // 3. The requests not yet written to FUSE device. + // Early terminate. + // Will reach callFutureLocked() `connected` check and return. + close(conn.fd.fullQueueCh) + + // TODO(gvisor.dev/issue/3528): Forget all pending forget reqs. } diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go new file mode 100644 index 000000000..91d16c1cf --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "math/rand" + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" +) + +// TestConnectionInitBlock tests if initialization +// correctly blocks and unblocks the connection. +// Since it's unfeasible to test kernelTask.Block() in unit test, +// the code in Call() are not tested here. +func TestConnectionInitBlock(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + + conn, _, err := newTestConnection(s, k, maxActiveRequestsDefault) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + select { + case <-conn.initializedChan: + t.Fatalf("initializedChan should be blocking before SetInitialized") + default: + } + + conn.SetInitialized() + + select { + case <-conn.initializedChan: + default: + t.Fatalf("initializedChan should not be blocking after SetInitialized") + } +} + +func TestConnectionAbort(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + creds := auth.CredentialsFromContext(s.Ctx) + task := kernel.TaskFromContext(s.Ctx) + + const numRequests uint64 = 256 + + conn, _, err := newTestConnection(s, k, numRequests) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + testObj := &testPayload{ + data: rand.Uint32(), + } + + var futNormal []*futureResponse + + for i := 0; i < int(numRequests); i++ { + req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + fut, err := conn.callFutureLocked(task, req) + if err != nil { + t.Fatalf("callFutureLocked failed: %v", err) + } + futNormal = append(futNormal, fut) + } + + conn.Abort(s.Ctx) + + // Abort should unblock the initialization channel. + // Note: no test requests are actually blocked on `conn.initializedChan`. + select { + case <-conn.initializedChan: + default: + t.Fatalf("initializedChan should not be blocking after SetInitialized") + } + + // Abort will return ECONNABORTED error to unblocked requests. + for _, fut := range futNormal { + if fut.getResponse().hdr.Error != -int32(syscall.ECONNABORTED) { + t.Fatalf("Incorrect error code received for aborted connection: %v", fut.getResponse().hdr.Error) + } + } + + // After abort, Call() should return directly with ENOTCONN. + req, err := conn.NewRequest(creds, 0, 0, 0, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + _, err = conn.Call(task, req) + if err != syserror.ENOTCONN { + t.Fatalf("Incorrect error code received for Call() after connection aborted") + } + +} diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index e522ff9a0..1b86a4b4c 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -32,6 +31,8 @@ import ( const fuseDevMinor = 229 // fuseDevice implements vfs.Device for /dev/fuse. +// +// +stateify savable type fuseDevice struct{} // Open implements vfs.Device.Open. @@ -50,15 +51,14 @@ func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, op } // DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse. +// +// +stateify savable type DeviceFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl vfs.NoLockFD - // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD. - mounted bool - // nextOpID is used to create new requests. nextOpID linux.FUSEOpID @@ -83,7 +83,7 @@ type DeviceFD struct { writeCursorFR *futureResponse // mu protects all the queues, maps, buffers and cursors and nextOpID. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // waitQueue is used to notify interested parties when the device becomes // readable or writable. @@ -92,21 +92,36 @@ type DeviceFD struct { // fullQueueCh is a channel used to synchronize the readers with the writers. // Writers (inbound requests to the filesystem) block if there are too many // unprocessed in-flight requests. - fullQueueCh chan struct{} + fullQueueCh chan struct{} `state:".(int)"` // fs is the FUSE filesystem that this FD is being used for. fs *filesystem } +func (fd *DeviceFD) saveFullQueueCh() int { + return cap(fd.fullQueueCh) +} + +func (fd *DeviceFD) loadFullQueueCh(capacity int) { + fd.fullQueueCh = make(chan struct{}, capacity) +} + // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DeviceFD) Release(context.Context) { - fd.fs.conn.connected = false +func (fd *DeviceFD) Release(ctx context.Context) { + if fd.fs != nil { + fd.fs.conn.mu.Lock() + fd.fs.conn.connected = false + fd.fs.conn.mu.Unlock() + + fd.fs.VFSFilesystem().DecRef(ctx) + fd.fs = nil + } } // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } @@ -116,10 +131,16 @@ func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset in // Read implements vfs.FileDescriptionImpl.Read. func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } + // Return ENODEV if the filesystem is umounted. + if fd.fs.umounted { + // TODO(gvisor.dev/issue/3525): return ECONNABORTED if aborted via fuse control fs. + return 0, syserror.ENODEV + } + // We require that any Read done on this filesystem have a sane minimum // read buffer. It must have the capacity for the fixed parts of any request // header (Linux uses the request header and the FUSEWriteIn header for this @@ -143,58 +164,82 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R } // readLocked implements the reading of the fuse device while locked with DeviceFD.mu. +// +// Preconditions: dst is large enough for any reasonable request. func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - if fd.queue.Empty() { - return 0, syserror.ErrWouldBlock - } + var req *Request - var readCursor uint32 - var bytesRead int64 - for { - req := fd.queue.Front() - if dst.NumBytes() < int64(req.hdr.Len) { - // The request is too large. Cannot process it. All requests must be smaller than the - // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT - // handshake. - errno := -int32(syscall.EIO) - if req.hdr.Opcode == linux.FUSE_SETXATTR { - errno = -int32(syscall.E2BIG) - } + // Find the first valid request. + // For the normal case this loop only execute once. + for !fd.queue.Empty() { + req = fd.queue.Front() - // Return the error to the calling task. - if err := fd.sendError(ctx, errno, req); err != nil { - return 0, err - } + if int64(req.hdr.Len)+int64(len(req.payload)) <= dst.NumBytes() { + break + } - // We're done with this request. - fd.queue.Remove(req) + // The request is too large. Cannot process it. All requests must be smaller than the + // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT + // handshake. + errno := -int32(syscall.EIO) + if req.hdr.Opcode == linux.FUSE_SETXATTR { + errno = -int32(syscall.E2BIG) + } - // Restart the read as this request was invalid. - log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.") - return fd.readLocked(ctx, dst, opts) + // Return the error to the calling task. + if err := fd.sendError(ctx, errno, req.hdr.Unique); err != nil { + return 0, err } - n, err := dst.CopyOut(ctx, req.data[readCursor:]) + // We're done with this request. + fd.queue.Remove(req) + req = nil + } + + if req == nil { + return 0, syserror.ErrWouldBlock + } + + // We already checked the size: dst must be able to fit the whole request. + // Now we write the marshalled header, the payload, + // and the potential additional payload + // to the user memory IOSequence. + + n, err := dst.CopyOut(ctx, req.data) + if err != nil { + return 0, err + } + if n != len(req.data) { + return 0, syserror.EIO + } + + if req.hdr.Opcode == linux.FUSE_WRITE { + written, err := dst.DropFirst(n).CopyOut(ctx, req.payload) if err != nil { return 0, err } - readCursor += uint32(n) - bytesRead += int64(n) - - if readCursor >= req.hdr.Len { - // Fully done with this req, remove it from the queue. - fd.queue.Remove(req) - break + if written != len(req.payload) { + return 0, syserror.EIO } + n += int(written) } - return bytesRead, nil + // Fully done with this req, remove it from the queue. + fd.queue.Remove(req) + + // Remove noReply ones from map of requests expecting a reply. + if req.noReply { + fd.numActiveRequests -= 1 + delete(fd.completions, req.hdr.Unique) + } + + return int64(n), nil } // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } @@ -211,10 +256,15 @@ func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs. // writeLocked implements writing to the fuse device while locked with DeviceFD.mu. func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } + // Return ENODEV if the filesystem is umounted. + if fd.fs.umounted { + return 0, syserror.ENODEV + } + var cn, n int64 hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) @@ -276,7 +326,8 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt fut, ok := fd.completions[hdr.Unique] if !ok { - // Server sent us a response for a request we never sent? + // Server sent us a response for a request we never sent, + // or for which we already received a reply (e.g. aborted), an unlikely event. return 0, syserror.EINVAL } @@ -307,8 +358,23 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt // Readiness implements vfs.FileDescriptionImpl.Readiness. func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.readinessLocked(mask) +} + +// readinessLocked implements checking the readiness of the fuse device while +// locked with DeviceFD.mu. +func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask { var ready waiter.EventMask - ready |= waiter.EventOut // FD is always writable + + if fd.fs.umounted { + ready |= waiter.EventErr + return ready & mask + } + + // FD is always writable. + ready |= waiter.EventOut if !fd.queue.Empty() { // Have reqs available, FD is readable. ready |= waiter.EventIn @@ -330,7 +396,7 @@ func (fd *DeviceFD) EventUnregister(e *waiter.Entry) { // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } @@ -338,59 +404,59 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64 } // sendResponse sends a response to the waiting task (if any). +// +// Preconditions: fd.mu must be held. func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error { - // See if the running task need to perform some action before returning. - // Since we just finished writing the future, we can be sure that - // getResponse generates a populated response. - if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil { - return err - } + // Signal the task waiting on a response if any. + defer close(fut.ch) // Signal that the queue is no longer full. select { case fd.fullQueueCh <- struct{}{}: default: } - fd.numActiveRequests -= 1 + fd.numActiveRequests-- + + if fut.async { + return fd.asyncCallBack(ctx, fut.getResponse()) + } - // Signal the task waiting on a response. - close(fut.ch) return nil } -// sendError sends an error response to the waiting task (if any). -func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error { +// sendError sends an error response to the waiting task (if any) by calling sendResponse(). +// +// Preconditions: fd.mu must be held. +func (fd *DeviceFD) sendError(ctx context.Context, errno int32, unique linux.FUSEOpID) error { // Return the error to the calling task. outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) respHdr := linux.FUSEHeaderOut{ Len: outHdrLen, Error: errno, - Unique: req.hdr.Unique, + Unique: unique, } fut, ok := fd.completions[respHdr.Unique] if !ok { - // Server sent us a response for a request we never sent? + // A response for a request we never sent, + // or for which we already received a reply (e.g. aborted). return syserror.EINVAL } delete(fd.completions, respHdr.Unique) fut.hdr = &respHdr - if err := fd.sendResponse(ctx, fut); err != nil { - return err - } - - return nil + return fd.sendResponse(ctx, fut) } -// noReceiverAction has the calling kernel.Task do some action if its known that no -// receiver is going to be waiting on the future channel. This is to be used by: -// FUSE_INIT. -func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error { - if r.opcode == linux.FUSE_INIT { +// asyncCallBack executes pre-defined callback function for async requests. +// Currently used by: FUSE_INIT. +func (fd *DeviceFD) asyncCallBack(ctx context.Context, r *Response) error { + switch r.opcode { + case linux.FUSE_INIT: creds := auth.CredentialsFromContext(ctx) rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace() return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs)) + // TODO(gvisor.dev/issue/3247): support async read: correctly process the response. } return nil diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 1ffe7ccd2..5986133e9 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -16,7 +16,6 @@ package fuse import ( "fmt" - "io" "math/rand" "testing" @@ -28,17 +27,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // echoTestOpcode is the Opcode used during testing. The server used in tests // will simply echo the payload back with the appropriate headers. const echoTestOpcode linux.FUSEOpcode = 1000 -type testPayload struct { - data uint32 -} - // TestFUSECommunication tests that the communication layer between the Sentry and the // FUSE server daemon works as expected. func TestFUSECommunication(t *testing.T) { @@ -327,102 +321,3 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F } } } - -func setup(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Error creating kernel: %v", err) - } - - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - - k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserList: true, - AllowUserMount: true, - }) - - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) - if err != nil { - t.Fatalf("NewMountNamespace(): %v", err) - } - - return testutil.NewSystem(ctx, t, k.VFS(), mntns) -} - -// newTestConnection creates a fuse connection that the sentry can communicate with -// and the FD for the server to communicate with. -func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { - vfsObj := &vfs.VirtualFilesystem{} - fuseDev := &DeviceFD{} - - if err := vfsObj.Init(system.Ctx); err != nil { - return nil, nil, err - } - - vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef(system.Ctx) - if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { - return nil, nil, err - } - - fsopts := filesystemOptions{ - maxActiveRequests: maxActiveRequests, - } - fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) - if err != nil { - return nil, nil, err - } - - return fs.conn, &fuseDev.vfsfd, nil -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *testPayload) SizeBytes() int { - return 4 -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *testPayload) MarshalBytes(dst []byte) { - usermem.ByteOrder.PutUint32(dst[:4], t.data) -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} -} - -// Packed implements marshal.Marshallable.Packed. -func (t *testPayload) Packed() bool { - return true -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (t *testPayload) MarshalUnsafe(dst []byte) { - t.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (t *testPayload) UnmarshalUnsafe(src []byte) { - t.UnmarshalBytes(src) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { - panic("not implemented") -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { - panic("not implemented") -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - panic("not implemented") -} - -// WriteTo implements io.WriterTo.WriteTo. -func (t *testPayload) WriteTo(w io.Writer) (int64, error) { - panic("not implemented") -} diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go new file mode 100644 index 000000000..8f220a04b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type directoryFD struct { + fileDescription +} + +// Allocate implements directoryFD.Allocate. +func (*directoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EISDIR +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (*directoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (*directoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (*directoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (*directoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback) error { + fusefs := dir.inode().fs + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + + in := linux.FUSEReadIn{ + Fh: dir.Fh, + Offset: uint64(atomic.LoadInt64(&dir.off)), + Size: linux.FUSE_PAGE_SIZE, + Flags: dir.statusFlags(), + } + + // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS. + req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) + if err != nil { + return err + } + + res, err := fusefs.conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + + var out linux.FUSEDirents + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + for _, fuseDirent := range out.Dirents { + nextOff := int64(fuseDirent.Meta.Off) + dirent := vfs.Dirent{ + Name: fuseDirent.Name, + Type: uint8(fuseDirent.Meta.Type), + Ino: fuseDirent.Meta.Ino, + NextOff: nextOff, + } + + if err := callback.Handle(dirent); err != nil { + return err + } + atomic.StoreInt64(&dir.off, nextOff) + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go new file mode 100644 index 000000000..83f2816b7 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/file.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// fileDescription implements vfs.FileDescriptionImpl for fuse. +type fileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD + + // the file handle used in userspace. + Fh uint64 + + // Nonseekable is indicate cannot perform seek on a file. + Nonseekable bool + + // DirectIO suggest fuse to use direct io operation. + DirectIO bool + + // OpenFlag is the flag returned by open. + OpenFlag uint32 + + // off is the file offset. + off int64 +} + +func (fd *fileDescription) dentry() *kernfs.Dentry { + return fd.vfsfd.Dentry().Impl().(*kernfs.Dentry) +} + +func (fd *fileDescription) inode() *inode { + return fd.dentry().Inode().(*inode) +} + +func (fd *fileDescription) filesystem() *vfs.Filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem() +} + +func (fd *fileDescription) statusFlags() uint32 { + return fd.vfsfd.StatusFlags() +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *fileDescription) Release(ctx context.Context) { + // no need to release if FUSE server doesn't implement Open. + conn := fd.inode().fs.conn + if conn.noOpen { + return + } + + in := linux.FUSEReleaseIn{ + Fh: fd.Fh, + Flags: fd.statusFlags(), + } + // TODO(gvisor.dev/issue/3245): add logic when we support file lock owner. + var opcode linux.FUSEOpcode + if fd.inode().Mode().IsDir() { + opcode = linux.FUSE_RELEASEDIR + } else { + opcode = linux.FUSE_RELEASE + } + kernelTask := kernel.TaskFromContext(ctx) + // ignoring errors and FUSE server reply is analogous to Linux's behavior. + req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) + if err != nil { + // No way to invoke Call() with an errored request. + return + } + // The reply will be ignored since no callback is defined in asyncCallBack(). + conn.CallAsync(kernelTask, req) +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, nil +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return 0, nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.filesystem() + inode := fd.inode() + return inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + fs := fd.filesystem() + creds := auth.CredentialsFromContext(ctx) + return fd.inode().setAttr(ctx, fs, creds, opts, true, fd.Fh) +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 810819ae4..65786e42a 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -16,24 +16,36 @@ package fuse import ( + "math" "strconv" + "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" ) // Name is the default filesystem name. const Name = "fuse" +// maxActiveRequestsDefault is the default setting controlling the upper bound +// on the number of active requests at any given time. +const maxActiveRequestsDefault = 10000 + // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} +// +stateify savable type filesystemOptions struct { // userID specifies the numeric uid of the mount owner. // This option should not be specified by the filesystem owner. @@ -56,9 +68,16 @@ type filesystemOptions struct { // exist at any time. Any further requests will block when trying to // Call the server. maxActiveRequests uint64 + + // maxRead is the max number of bytes to read, + // specified as "max_read" in fs parameters. + // If not specified by user, use math.MaxUint32 as default value. + maxRead uint32 } // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { kernfs.Filesystem devMinor uint32 @@ -69,6 +88,9 @@ type filesystem struct { // opts is the options the fusefs is initialized with. opts *filesystemOptions + + // umounted is true if filesystem.Release() has been called. + umounted bool } // Name implements vfs.FilesystemType.Name. @@ -142,14 +164,29 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Set the maxInFlightRequests option. fsopts.maxActiveRequests = maxActiveRequestsDefault + if maxReadStr, ok := mopts["max_read"]; ok { + delete(mopts, "max_read") + maxRead, err := strconv.ParseUint(maxReadStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid max_read: max_read=%s", fsType.Name(), maxReadStr) + return nil, nil, syserror.EINVAL + } + if maxRead < fuseMinMaxRead { + maxRead = fuseMinMaxRead + } + fsopts.maxRead = uint32(maxRead) + } else { + fsopts.maxRead = math.MaxUint32 + } + // Check for unparsed options. if len(mopts) != 0 { - log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts) + log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts) return nil, nil, syserror.EINVAL } // Create a new FUSE filesystem. - fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + fs, err := newFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) if err != nil { log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) return nil, nil, err @@ -165,26 +202,28 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // root is the fusefs root directory. - root := fs.newInode(creds, fsopts.rootMode) + root := fs.newRootInode(creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } -// NewFUSEFilesystem creates a new FUSE filesystem. -func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { - fs := &filesystem{ - devMinor: devMinor, - opts: opts, - } - - conn, err := newFUSEConnection(ctx, device, opts.maxActiveRequests) +// newFUSEFilesystem creates a new FUSE filesystem. +func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { + conn, err := newFUSEConnection(ctx, device, opts) if err != nil { log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) return nil, syserror.EINVAL } - fs.conn = conn fuseFD := device.Impl().(*DeviceFD) + + fs := &filesystem{ + devMinor: devMinor, + opts: opts, + conn: conn, + } + + fs.VFSFilesystem().IncRef() fuseFD.fs = fs return fs, nil @@ -192,11 +231,22 @@ func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { + fs.conn.fd.mu.Lock() + + fs.umounted = true + fs.conn.Abort(ctx) + // Notify all the waiters on this fd. + fs.conn.fd.waitQueue.Notify(waiter.EventIn) + + fs.conn.fd.mu.Unlock() + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) fs.Filesystem.Release(ctx) } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { inodeRefs kernfs.InodeAttrs @@ -205,14 +255,50 @@ type inode struct { kernfs.InodeNotSymlink kernfs.OrderedChildren + dentry kernfs.Dentry + + // the owning filesystem. fs is immutable. + fs *filesystem + + // metaDataMu protects the metadata of this inode. + metadataMu sync.Mutex + + nodeID uint64 + locks vfs.FileLocks - dentry kernfs.Dentry + // size of the file. + size uint64 + + // attributeVersion is the version of inode's attributes. + attributeVersion uint64 + + // attributeTime is the remaining vaild time of attributes. + attributeTime uint64 + + // version of the inode. + version uint64 + + // link is result of following a symbolic link. + link string +} + +func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { + i := &inode{fs: fs} + i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + i.EnableLeakCheck() + i.dentry.Init(i) + i.nodeID = 1 + + return &i.dentry } -func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { - i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) +func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) *kernfs.Dentry { + i := &inode{fs: fs, nodeID: nodeID} + creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} + i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() i.dentry.Init(i) @@ -221,14 +307,292 @@ func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *ke } // Open implements kernfs.Inode.Open. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ - SeekEnd: kernfs.SeekEndStaticEntries, - }) +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + isDir := i.InodeAttrs.Mode().IsDir() + // return error if specified to open directory but inode is not a directory. + if !isDir && opts.Mode.IsDir() { + return nil, syserror.ENOTDIR + } + if opts.Flags&linux.O_LARGEFILE == 0 && atomic.LoadUint64(&i.size) > linux.MAX_NON_LFS { + return nil, syserror.EOVERFLOW + } + + var fd *fileDescription + var fdImpl vfs.FileDescriptionImpl + if isDir { + directoryFD := &directoryFD{} + fd = &(directoryFD.fileDescription) + fdImpl = directoryFD + } else { + regularFD := ®ularFileFD{} + fd = &(regularFD.fileDescription) + fdImpl = regularFD + } + // FOPEN_KEEP_CACHE is the defualt flag for noOpen. + fd.OpenFlag = linux.FOPEN_KEEP_CACHE + + // Only send open request when FUSE server support open or is opening a directory. + if !i.fs.conn.noOpen || isDir { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.Open: couldn't get kernel task from context") + return nil, syserror.EINVAL + } + + // Build the request. + var opcode linux.FUSEOpcode + if isDir { + opcode = linux.FUSE_OPENDIR + } else { + opcode = linux.FUSE_OPEN + } + + in := linux.FUSEOpenIn{Flags: opts.Flags & ^uint32(linux.O_CREAT|linux.O_EXCL|linux.O_NOCTTY)} + if !i.fs.conn.atomicOTrunc { + in.Flags &= ^uint32(linux.O_TRUNC) + } + + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) + if err != nil { + return nil, err + } + + // Send the request and receive the reply. + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return nil, err + } + if err := res.Error(); err == syserror.ENOSYS && !isDir { + i.fs.conn.noOpen = true + } else if err != nil { + return nil, err + } else { + out := linux.FUSEOpenOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return nil, err + } + + // Process the reply. + fd.OpenFlag = out.OpenFlag + if isDir { + fd.OpenFlag &= ^uint32(linux.FOPEN_DIRECT_IO) + } + + fd.Fh = out.Fh + } + } + + // TODO(gvisor.dev/issue/3234): invalidate mmap after implemented it for FUSE Inode + fd.DirectIO = fd.OpenFlag&linux.FOPEN_DIRECT_IO != 0 + fdOptions := &vfs.FileDescriptionOptions{} + if fd.OpenFlag&linux.FOPEN_NONSEEKABLE != 0 { + fdOptions.DenyPRead = true + fdOptions.DenyPWrite = true + fd.Nonseekable = true + } + + // If we don't send SETATTR before open (which is indicated by atomicOTrunc) + // and O_TRUNC is set, update the inode's version number and clean existing data + // by setting the file size to 0. + if i.fs.conn.atomicOTrunc && opts.Flags&linux.O_TRUNC != 0 { + i.fs.conn.mu.Lock() + i.fs.conn.attributeVersion++ + i.attributeVersion = i.fs.conn.attributeVersion + atomic.StoreUint64(&i.size, 0) + i.fs.conn.mu.Unlock() + i.attributeTime = 0 + } + + if err := fd.vfsfd.Init(fdImpl, opts.Flags, rp.Mount(), d.VFSDentry(), fdOptions); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// Lookup implements kernfs.Inode.Lookup. +func (i *inode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { + in := linux.FUSELookupIn{Name: name} + return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in) +} + +// IterDirents implements kernfs.Inode.IterDirents. +func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return offset, nil +} + +// Valid implements kernfs.Inode.Valid. +func (*inode) Valid(ctx context.Context) bool { + return true +} + +// NewFile implements kernfs.Inode.NewFile. +func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID) + return nil, syserror.EINVAL + } + in := linux.FUSECreateIn{ + CreateMeta: linux.FUSECreateMeta{ + Flags: opts.Flags, + Mode: uint32(opts.Mode) | linux.S_IFREG, + Umask: uint32(kernelTask.FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in) +} + +// NewNode implements kernfs.Inode.NewNode. +func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*kernfs.Dentry, error) { + in := linux.FUSEMknodIn{ + MknodMeta: linux.FUSEMknodMeta{ + Mode: uint32(opts.Mode), + Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor), + Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in) +} + +// NewSymlink implements kernfs.Inode.NewSymlink. +func (i *inode) NewSymlink(ctx context.Context, name, target string) (*kernfs.Dentry, error) { + in := linux.FUSESymLinkIn{ + Name: name, + Target: target, + } + return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in) +} + +// Unlink implements kernfs.Inode.Unlink. +func (i *inode) Unlink(ctx context.Context, name string, child *kernfs.Dentry) error { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) + return syserror.EINVAL + } + in := linux.FUSEUnlinkIn{Name: name} + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) + if err != nil { + return err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return err + } + // only return error, discard res. + if err := res.Error(); err != nil { + return err + } + return i.dentry.RemoveChildLocked(name, child) +} + +// NewDir implements kernfs.Inode.NewDir. +func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) { + in := linux.FUSEMkdirIn{ + MkdirMeta: linux.FUSEMkdirMeta{ + Mode: uint32(opts.Mode), + Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in) +} + +// RmDir implements kernfs.Inode.RmDir. +func (i *inode) RmDir(ctx context.Context, name string, child *kernfs.Dentry) error { + fusefs := i.fs + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + + in := linux.FUSERmDirIn{Name: name} + req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) + if err != nil { + return err + } + + res, err := i.fs.conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + + return i.dentry.RemoveChildLocked(name, child) +} + +// newEntry calls FUSE server for entry creation and allocates corresponding entry according to response. +// Shared by FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and FUSE_LOOKUP. +func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*kernfs.Dentry, error) { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) + return nil, syserror.EINVAL + } + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) + if err != nil { + return nil, err + } + res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err } - return fd.VFSFileDescription(), nil + if err := res.Error(); err != nil { + return nil, err + } + out := linux.FUSEEntryOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return nil, err + } + if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { + return nil, syserror.EIO + } + child := i.fs.newInode(out.NodeID, out.Attr) + return child, nil +} + +// Getlink implements kernfs.Inode.Getlink. +func (i *inode) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + path, err := i.Readlink(ctx, mnt) + return vfs.VirtualDentry{}, path, err +} + +// Readlink implements kernfs.Inode.Readlink. +func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { + if i.Mode().FileType()&linux.S_IFLNK == 0 { + return "", syserror.EINVAL + } + if len(i.link) == 0 { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") + return "", syserror.EINVAL + } + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) + if err != nil { + return "", err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return "", err + } + i.link = string(res.data[res.hdr.SizeBytes():]) + if !mnt.Options().ReadOnly { + i.attributeTime = 0 + } + } + return i.link, nil +} + +// getFUSEAttr returns a linux.FUSEAttr of this inode stored in local cache. +// TODO(gvisor.dev/issue/3679): Add support for other fields. +func (i *inode) getFUSEAttr() linux.FUSEAttr { + return linux.FUSEAttr{ + Ino: i.Ino(), + Size: atomic.LoadUint64(&i.size), + Mode: uint32(i.Mode()), + } } // statFromFUSEAttr makes attributes from linux.FUSEAttr to linux.Statx. The @@ -284,50 +648,92 @@ func statFromFUSEAttr(attr linux.FUSEAttr, mask, devMinor uint32) linux.Statx { return stat } -// Stat implements kernfs.Inode.Stat. -func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - fusefs := fs.Impl().(*filesystem) - conn := fusefs.conn - task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) +// getAttr gets the attribute of this inode by issuing a FUSE_GETATTR request +// or read from local cache. It updates the corresponding attributes if +// necessary. +func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions, flags uint32, fh uint64) (linux.FUSEAttr, error) { + attributeVersion := atomic.LoadUint64(&i.fs.conn.attributeVersion) + + // TODO(gvisor.dev/issue/3679): send the request only if + // - invalid local cache for fields specified in the opts.Mask + // - forced update + // - i.attributeTime expired + // If local cache is still valid, return local cache. + // Currently we always send a request, + // and we always set the metadata with the new result, + // unless attributeVersion has changed. + + task := kernel.TaskFromContext(ctx) if task == nil { log.Warningf("couldn't get kernel task from context") - return linux.Statx{}, syserror.EINVAL + return linux.FUSEAttr{}, syserror.EINVAL } - var in linux.FUSEGetAttrIn - // We don't set any attribute in the request, because in VFS2 fstat(2) will - // finally be translated into vfs.FilesystemImpl.StatAt() (see - // pkg/sentry/syscalls/linux/vfs2/stat.go), resulting in the same flow - // as stat(2). Thus GetAttrFlags and Fh variable will never be used in VFS2. - req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.Ino(), linux.FUSE_GETATTR, &in) + creds := auth.CredentialsFromContext(ctx) + + in := linux.FUSEGetAttrIn{ + GetAttrFlags: flags, + Fh: fh, + } + req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) if err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } - res, err := conn.Call(task, req) + res, err := i.fs.conn.Call(task, req) if err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } if err := res.Error(); err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } var out linux.FUSEGetAttrOut if err := res.UnmarshalPayload(&out); err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err + } + + // Local version is newer, return the local one. + // Skip the update. + if attributeVersion != 0 && atomic.LoadUint64(&i.attributeVersion) > attributeVersion { + return i.getFUSEAttr(), nil } - // Set all metadata into kernfs.InodeAttrs. - if err := i.SetStat(ctx, fs, creds, vfs.SetStatOptions{ - Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, fusefs.devMinor), + // Set the metadata of kernfs.InodeAttrs. + if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { + return linux.FUSEAttr{}, err + } + + // Set the size if no error (after SetStat() check). + atomic.StoreUint64(&i.size, out.Attr.Size) + + return out.Attr, nil +} + +// reviseAttr attempts to update the attributes for internal purposes +// by calling getAttr with a pre-specified mask. +// Used by read, write, lseek. +func (i *inode) reviseAttr(ctx context.Context, flags uint32, fh uint64) error { + // Never need atime for internal purposes. + _, err := i.getAttr(ctx, i.fs.VFSFilesystem(), vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS &^ linux.STATX_ATIME, + }, flags, fh) + return err +} + +// Stat implements kernfs.Inode.Stat. +func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + attr, err := i.getAttr(ctx, fs, opts, 0, 0) + if err != nil { return linux.Statx{}, err } - return statFromFUSEAttr(out.Attr, opts.Mask, fusefs.devMinor), nil + return statFromFUSEAttr(attr, opts.Mask, i.fs.devMinor), nil } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(context.Context) { i.inodeRefs.DecRef(i.Destroy) } @@ -337,3 +743,84 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // TODO(gvisor.dev/issues/3413): Complete the implementation of statfs. return vfs.GenericStatFS(linux.FUSE_SUPER_MAGIC), nil } + +// fattrMaskFromStats converts vfs.SetStatOptions.Stat.Mask to linux stats mask +// aligned with the attribute mask defined in include/linux/fs.h. +func fattrMaskFromStats(mask uint32) uint32 { + var fuseAttrMask uint32 + maskMap := map[uint32]uint32{ + linux.STATX_MODE: linux.FATTR_MODE, + linux.STATX_UID: linux.FATTR_UID, + linux.STATX_GID: linux.FATTR_GID, + linux.STATX_SIZE: linux.FATTR_SIZE, + linux.STATX_ATIME: linux.FATTR_ATIME, + linux.STATX_MTIME: linux.FATTR_MTIME, + linux.STATX_CTIME: linux.FATTR_CTIME, + } + for statxMask, fattrMask := range maskMap { + if mask&statxMask != 0 { + fuseAttrMask |= fattrMask + } + } + return fuseAttrMask +} + +// SetStat implements kernfs.Inode.SetStat. +func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return i.setAttr(ctx, fs, creds, opts, false, 0) +} + +func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions, useFh bool, fh uint64) error { + conn := i.fs.conn + task := kernel.TaskFromContext(ctx) + if task == nil { + log.Warningf("couldn't get kernel task from context") + return syserror.EINVAL + } + + // We should retain the original file type when assigning new mode. + fileType := uint16(i.Mode()) & linux.S_IFMT + fattrMask := fattrMaskFromStats(opts.Stat.Mask) + if useFh { + fattrMask |= linux.FATTR_FH + } + in := linux.FUSESetAttrIn{ + Valid: fattrMask, + Fh: fh, + Size: opts.Stat.Size, + Atime: uint64(opts.Stat.Atime.Sec), + Mtime: uint64(opts.Stat.Mtime.Sec), + Ctime: uint64(opts.Stat.Ctime.Sec), + AtimeNsec: opts.Stat.Atime.Nsec, + MtimeNsec: opts.Stat.Mtime.Nsec, + CtimeNsec: opts.Stat.Ctime.Nsec, + Mode: uint32(fileType | opts.Stat.Mode), + UID: opts.Stat.UID, + GID: opts.Stat.GID, + } + req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) + if err != nil { + return err + } + + res, err := conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + out := linux.FUSEGetAttrOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + // Set the metadata of kernfs.InodeAttrs. + if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), + }); err != nil { + return err + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go new file mode 100644 index 000000000..625d1547f --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -0,0 +1,242 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// ReadInPages sends FUSE_READ requests for the size after round it up to +// a multiple of page size, blocks on it for reply, processes the reply +// and returns the payload (or joined payloads) as a byte slice. +// This is used for the general purpose reading. +// We do not support direct IO (which read the exact number of bytes) +// at this moment. +func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off uint64, size uint32) ([][]byte, uint32, error) { + attributeVersion := atomic.LoadUint64(&fs.conn.attributeVersion) + + t := kernel.TaskFromContext(ctx) + if t == nil { + log.Warningf("fusefs.Read: couldn't get kernel task from context") + return nil, 0, syserror.EINVAL + } + + // Round up to a multiple of page size. + readSize, _ := usermem.PageRoundUp(uint64(size)) + + // One request cannnot exceed either maxRead or maxPages. + maxPages := fs.conn.maxRead >> usermem.PageShift + if maxPages > uint32(fs.conn.maxPages) { + maxPages = uint32(fs.conn.maxPages) + } + + var outs [][]byte + var sizeRead uint32 + + // readSize is a multiple of usermem.PageSize. + // Always request bytes as a multiple of pages. + pagesRead, pagesToRead := uint32(0), uint32(readSize>>usermem.PageShift) + + // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. + in := linux.FUSEReadIn{ + Fh: fd.Fh, + LockOwner: 0, // TODO(gvisor.dev/issue/3245): file lock + ReadFlags: 0, // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER + Flags: fd.statusFlags(), + } + + // This loop is intended for fragmented read where the bytes to read is + // larger than either the maxPages or maxRead. + // For the majority of reads with normal size, this loop should only + // execute once. + for pagesRead < pagesToRead { + pagesCanRead := pagesToRead - pagesRead + if pagesCanRead > maxPages { + pagesCanRead = maxPages + } + + in.Offset = off + (uint64(pagesRead) << usermem.PageShift) + in.Size = pagesCanRead << usermem.PageShift + + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) + if err != nil { + return nil, 0, err + } + + // TODO(gvisor.dev/issue/3247): support async read. + + res, err := fs.conn.Call(t, req) + if err != nil { + return nil, 0, err + } + if err := res.Error(); err != nil { + return nil, 0, err + } + + // Not enough bytes in response, + // either we reached EOF, + // or the FUSE server sends back a response + // that cannot even fit the hdr. + if len(res.data) <= res.hdr.SizeBytes() { + // We treat both case as EOF here for now + // since there is no reliable way to detect + // the over-short hdr case. + break + } + + // Directly using the slice to avoid extra copy. + out := res.data[res.hdr.SizeBytes():] + + outs = append(outs, out) + sizeRead += uint32(len(out)) + + pagesRead += pagesCanRead + } + + defer fs.ReadCallback(ctx, fd, off, size, sizeRead, attributeVersion) + + // No bytes returned: offset >= EOF. + if len(outs) == 0 { + return nil, 0, io.EOF + } + + return outs, sizeRead, nil +} + +// ReadCallback updates several information after receiving a read response. +// Due to readahead, sizeRead can be larger than size. +func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off uint64, size uint32, sizeRead uint32, attributeVersion uint64) { + // TODO(gvisor.dev/issue/3247): support async read. + // If this is called by an async read, correctly process it. + // May need to update the signature. + + i := fd.inode() + // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + + // Reached EOF. + if sizeRead < size { + // TODO(gvisor.dev/issue/3630): If we have writeback cache, then we need to fill this hole. + // Might need to update the buf to be returned from the Read(). + + // Update existing size. + newSize := off + uint64(sizeRead) + fs.conn.mu.Lock() + if attributeVersion == i.attributeVersion && newSize < atomic.LoadUint64(&i.size) { + fs.conn.attributeVersion++ + i.attributeVersion = i.fs.conn.attributeVersion + atomic.StoreUint64(&i.size, newSize) + } + fs.conn.mu.Unlock() + } +} + +// Write sends FUSE_WRITE requests and return the bytes +// written according to the response. +// +// Preconditions: len(data) == size. +func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, size uint32, data []byte) (uint32, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + log.Warningf("fusefs.Read: couldn't get kernel task from context") + return 0, syserror.EINVAL + } + + // One request cannnot exceed either maxWrite or maxPages. + maxWrite := uint32(fs.conn.maxPages) << usermem.PageShift + if maxWrite > fs.conn.maxWrite { + maxWrite = fs.conn.maxWrite + } + + // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. + in := linux.FUSEWriteIn{ + Fh: fd.Fh, + // TODO(gvisor.dev/issue/3245): file lock + LockOwner: 0, + // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER + // TODO(gvisor.dev/issue/3237): |= linux.FUSE_WRITE_CACHE (not added yet) + WriteFlags: 0, + Flags: fd.statusFlags(), + } + + var written uint32 + + // This loop is intended for fragmented write where the bytes to write is + // larger than either the maxWrite or maxPages or when bigWrites is false. + // Unless a small value for max_write is explicitly used, this loop + // is expected to execute only once for the majority of the writes. + for written < size { + toWrite := size - written + + // Limit the write size to one page. + // Note that the bigWrites flag is obsolete, + // latest libfuse always sets it on. + if !fs.conn.bigWrites && toWrite > usermem.PageSize { + toWrite = usermem.PageSize + } + + // Limit the write size to maxWrite. + if toWrite > maxWrite { + toWrite = maxWrite + } + + in.Offset = off + uint64(written) + in.Size = toWrite + + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + if err != nil { + return 0, err + } + + req.payload = data[written : written+toWrite] + + // TODO(gvisor.dev/issue/3247): support async write. + + res, err := fs.conn.Call(t, req) + if err != nil { + return 0, err + } + if err := res.Error(); err != nil { + return 0, err + } + + out := linux.FUSEWriteOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return 0, err + } + + // Write more than requested? EIO. + if out.Size > toWrite { + return 0, syserror.EIO + } + + written += out.Size + + // Break if short write. Not necessarily an error. + if out.Size != toWrite { + break + } + } + + return written, nil +} diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go new file mode 100644 index 000000000..5bdd096c3 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/regular_file.go @@ -0,0 +1,230 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "math" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type regularFileFD struct { + fileDescription + + // off is the file offset. + off int64 + // offMu protects off. + offMu sync.Mutex +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if offset < 0 { + return 0, syserror.EINVAL + } + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + + size := dst.NumBytes() + if size == 0 { + // Early return if count is 0. + return 0, nil + } else if size > math.MaxUint32 { + // FUSE only supports uint32 for size. + // Overflow. + return 0, syserror.EINVAL + } + + // TODO(gvisor.dev/issue/3678): Add direct IO support. + + inode := fd.inode() + + // Reading beyond EOF, update file size if outdated. + if uint64(offset+size) > atomic.LoadUint64(&inode.size) { + if err := inode.reviseAttr(ctx, linux.FUSE_GETATTR_FH, fd.Fh); err != nil { + return 0, err + } + // If the offset after update is still too large, return error. + if uint64(offset) >= atomic.LoadUint64(&inode.size) { + return 0, io.EOF + } + } + + // Truncate the read with updated file size. + fileSize := atomic.LoadUint64(&inode.size) + if uint64(offset+size) > fileSize { + size = int64(fileSize) - offset + } + + buffers, n, err := inode.fs.ReadInPages(ctx, fd, uint64(offset), uint32(size)) + if err != nil { + return 0, err + } + + // TODO(gvisor.dev/issue/3237): support indirect IO (e.g. caching), + // store the bytes that were read ahead. + + // Update the number of bytes to copy for short read. + if n < uint32(size) { + size = int64(n) + } + + // Copy the bytes read to the dst. + // This loop is intended for fragmented reads. + // For the majority of reads, this loop only execute once. + var copied int64 + for _, buffer := range buffers { + toCopy := int64(len(buffer)) + if copied+toCopy > size { + toCopy = size - copied + } + cp, err := dst.DropFirst64(copied).CopyOut(ctx, buffer[:toCopy]) + if err != nil { + return 0, err + } + if int64(cp) != toCopy { + return 0, syserror.EIO + } + copied += toCopy + } + + return copied, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.offMu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.offMu.Unlock() + return n, err +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + fd.offMu.Lock() + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off + fd.offMu.Unlock() + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { + if offset < 0 { + return 0, offset, syserror.EINVAL + } + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + + inode := fd.inode() + inode.metadataMu.Lock() + defer inode.metadataMu.Unlock() + + // If the file is opened with O_APPEND, update offset to file size. + // Note: since our Open() implements the interface of kernfs, + // and kernfs currently does not support O_APPEND, this will never + // be true before we switch out from kernfs. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking inode.metadataMu is sufficient for reading size + offset = int64(inode.size) + } + + srclen := src.NumBytes() + + if srclen > math.MaxUint32 { + // FUSE only supports uint32 for size. + // Overflow. + return 0, offset, syserror.EINVAL + } + if end := offset + srclen; end < offset { + // Overflow. + return 0, offset, syserror.EINVAL + } + + srclen, err = vfs.CheckLimit(ctx, offset, srclen) + if err != nil { + return 0, offset, err + } + + if srclen == 0 { + // Return before causing any side effects. + return 0, offset, nil + } + + src = src.TakeFirst64(srclen) + + // TODO(gvisor.dev/issue/3237): Add cache support: + // buffer cache. Ideally we write from src to our buffer cache first. + // The slice passed to fs.Write() should be a slice from buffer cache. + data := make([]byte, srclen) + // Reason for making a copy here: connection.Call() blocks on kerneltask, + // which in turn acquires mm.activeMu lock. Functions like CopyInTo() will + // attemp to acquire the mm.activeMu lock as well -> deadlock. + // We must finish reading from the userspace memory before + // t.Block() deactivates it. + cp, err := src.CopyIn(ctx, data) + if err != nil { + return 0, offset, err + } + if int64(cp) != srclen { + return 0, offset, syserror.EIO + } + + n, err := fd.inode().fs.Write(ctx, fd, uint64(offset), uint32(srclen), data) + if err != nil { + return 0, offset, err + } + + if n == 0 { + // We have checked srclen != 0 previously. + // If err == nil, then it's a short write and we return EIO. + return 0, offset, syserror.EIO + } + + written = int64(n) + finalOff = offset + written + + if finalOff > int64(inode.size) { + atomic.StoreUint64(&inode.size, uint64(finalOff)) + atomic.AddUint64(&inode.fs.conn.attributeVersion, 1) + } + + return +} diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go new file mode 100644 index 000000000..7fa00569b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "fmt" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/usermem" +) + +// fuseInitRes is a variable-length wrapper of linux.FUSEInitOut. The FUSE +// server may implement an older version of FUSE protocol, which contains a +// linux.FUSEInitOut with less attributes. +// +// Dynamically-sized objects cannot be marshalled. +type fuseInitRes struct { + marshal.StubMarshallable + + // initOut contains the response from the FUSE server. + initOut linux.FUSEInitOut + + // initLen is the total length of bytes of the response. + initLen uint32 +} + +// UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes. +func (r *fuseInitRes) UnmarshalBytes(src []byte) { + out := &r.initOut + + // Introduced before FUSE kernel version 7.13. + out.Major = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.Minor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.MaxReadahead = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.Flags = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.MaxBackground = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + out.CongestionThreshold = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + out.MaxWrite = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + + // Introduced in FUSE kernel version 7.23. + if len(src) >= 4 { + out.TimeGran = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + // Introduced in FUSE kernel version 7.28. + if len(src) >= 2 { + out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + } +} + +// SizeBytes is the size of the payload of the FUSE_INIT response. +func (r *fuseInitRes) SizeBytes() int { + return int(r.initLen) +} + +// Ordinary requests have even IDs, while interrupts IDs are odd. +// Used to increment the unique ID for each FUSE request. +var reqIDStep uint64 = 2 + +// Request represents a FUSE operation request that hasn't been sent to the +// server yet. +// +// +stateify savable +type Request struct { + requestEntry + + id linux.FUSEOpID + hdr *linux.FUSEHeaderIn + data []byte + + // payload for this request: extra bytes to write after + // the data slice. Used by FUSE_WRITE. + payload []byte + + // If this request is async. + async bool + // If we don't care its response. + // Manually set by the caller. + noReply bool +} + +// NewRequest creates a new request that can be sent to the FUSE server. +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) + + hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() + hdr := linux.FUSEHeaderIn{ + Len: uint32(hdrLen + payload.SizeBytes()), + Opcode: opcode, + Unique: conn.fd.nextOpID, + NodeID: ino, + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + PID: pid, + } + + buf := make([]byte, hdr.Len) + + // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again. + hdr.MarshalBytes(buf[:hdrLen]) + payload.MarshalBytes(buf[hdrLen:]) + + return &Request{ + id: hdr.Unique, + hdr: &hdr, + data: buf, + }, nil +} + +// futureResponse represents an in-flight request, that may or may not have +// completed yet. Convert it to a resolved Response by calling Resolve, but note +// that this may block. +// +// +stateify savable +type futureResponse struct { + opcode linux.FUSEOpcode + ch chan struct{} + hdr *linux.FUSEHeaderOut + data []byte + + // If this request is async. + async bool +} + +// newFutureResponse creates a future response to a FUSE request. +func newFutureResponse(req *Request) *futureResponse { + return &futureResponse{ + opcode: req.hdr.Opcode, + ch: make(chan struct{}), + async: req.async, + } +} + +// resolve blocks the task until the server responds to its corresponding request, +// then returns a resolved response. +func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { + // Return directly for async requests. + if f.async { + return nil, nil + } + + if err := t.Block(f.ch); err != nil { + return nil, err + } + + return f.getResponse(), nil +} + +// getResponse creates a Response from the data the futureResponse has. +func (f *futureResponse) getResponse() *Response { + return &Response{ + opcode: f.opcode, + hdr: *f.hdr, + data: f.data, + } +} + +// Response represents an actual response from the server, including the +// response payload. +// +// +stateify savable +type Response struct { + opcode linux.FUSEOpcode + hdr linux.FUSEHeaderOut + data []byte +} + +// Error returns the error of the FUSE call. +func (r *Response) Error() error { + errno := r.hdr.Error + if errno >= 0 { + return nil + } + + sysErrNo := syscall.Errno(-errno) + return error(sysErrNo) +} + +// DataLen returns the size of the response without the header. +func (r *Response) DataLen() uint32 { + return r.hdr.Len - uint32(r.hdr.SizeBytes()) +} + +// UnmarshalPayload unmarshals the response data into m. +func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { + hdrLen := r.hdr.SizeBytes() + haveDataLen := r.hdr.Len - uint32(hdrLen) + wantDataLen := uint32(m.SizeBytes()) + + if haveDataLen < wantDataLen { + return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) + } + + // The response data is empty unless there is some payload. And so, doesn't + // need to be unmarshalled. + if r.data == nil { + return nil + } + + // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again. + m.UnmarshalBytes(r.data[hdrLen:]) + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go new file mode 100644 index 000000000..e1d9e3365 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +func setup(t *testing.T) *testutil.System { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("Error creating kernel: %v", err) + } + + ctx := k.SupervisorContext() + creds := auth.CredentialsFromContext(ctx) + + k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + AllowUserMount: true, + }) + + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) + if err != nil { + t.Fatalf("NewMountNamespace(): %v", err) + } + + return testutil.NewSystem(ctx, t, k.VFS(), mntns) +} + +// newTestConnection creates a fuse connection that the sentry can communicate with +// and the FD for the server to communicate with. +func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { + vfsObj := &vfs.VirtualFilesystem{} + fuseDev := &DeviceFD{} + + if err := vfsObj.Init(system.Ctx); err != nil { + return nil, nil, err + } + + vd := vfsObj.NewAnonVirtualDentry("genCountFD") + defer vd.DecRef(system.Ctx) + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, nil, err + } + + fsopts := filesystemOptions{ + maxActiveRequests: maxActiveRequests, + } + fs, err := newFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + if err != nil { + return nil, nil, err + } + + return fs.conn, &fuseDev.vfsfd, nil +} + +type testPayload struct { + marshal.StubMarshallable + data uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *testPayload) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *testPayload) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], t.data) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *testPayload) UnmarshalBytes(src []byte) { + *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} +} + +// Packed implements marshal.Marshallable.Packed. +func (t *testPayload) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *testPayload) MarshalUnsafe(dst []byte) { + t.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *testPayload) UnmarshalUnsafe(src []byte) { + t.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (t *testPayload) CopyOutN(task marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + panic("not implemented") +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *testPayload) CopyOut(task marshal.CopyContext, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *testPayload) CopyIn(task marshal.CopyContext, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *testPayload) WriteTo(w io.Writer) (int64, error) { + panic("not implemented") +} diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 91d2ae199..18c884b59 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -117,11 +117,12 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { d.syntheticChildren++ } +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 dirents []vfs.Dirent } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 5d0f487db..94d96261b 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -1026,7 +1026,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open // step is required even if !d.cachedMetadataAuthoritative() because // d.mappings has to be updated. // d.metadataMu has already been acquired if trunc == true. - d.updateFileSizeLocked(0) + d.updateSizeLocked(0) if d.cachedMetadataAuthoritative() { d.touchCMtimeLocked() @@ -1311,6 +1311,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if !renamed.isDir() { return syserror.EISDIR } + if genericIsAncestorDentry(replaced, renamed) { + return syserror.ENOTEMPTY + } } else { if rp.MustBeDir() || renamed.isDir() { return syserror.ENOTDIR @@ -1361,14 +1364,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // with reference counts and queue oldParent for checkCachingLocked if the // parent isn't actually changing. if oldParent != newParent { + oldParent.decRefLocked() ds = appendDentry(ds, oldParent) newParent.IncRef() if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ } + renamed.parent = newParent } - renamed.parent = newParent renamed.name = newName if newParent.children == nil { newParent.children = make(map[string]*dentry) @@ -1412,11 +1416,11 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil { - fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + err = d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + if err != nil { return err } - fs.renameMuRUnlockAndCheckCaching(ctx, &ds) if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) @@ -1491,7 +1495,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return fs.unlinkAt(ctx, rp, false /* dir */) } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() @@ -1519,8 +1523,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -1528,11 +1532,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si if err != nil { return nil, err } - return d.listxattr(ctx, rp.Credentials(), size) + return d.listXattr(ctx, rp.Credentials(), size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -1540,11 +1544,11 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt if err != nil { return "", err } - return d.getxattr(ctx, rp.Credentials(), &opts) + return d.getXattr(ctx, rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) @@ -1552,18 +1556,18 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil { - fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + err = d.setXattr(ctx, rp.Credentials(), &opts) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + if err != nil { return err } - fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) @@ -1571,11 +1575,11 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.removexattr(ctx, rp.Credentials(), name); err != nil { - fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + err = d.removeXattr(ctx, rp.Credentials(), name) + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + if err != nil { return err } - fs.renameMuRUnlockAndCheckCaching(ctx, &ds) d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 57bff1789..8608471f8 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -62,9 +62,13 @@ import ( const Name = "9p" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -77,7 +81,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client + client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -95,7 +99,7 @@ type filesystem struct { // reference count (such that it is usable as vfs.ResolvingPath.Start() or // is reachable from its children), or if it is a child dentry (such that // it is reachable from its parent). - renameMu sync.RWMutex + renameMu sync.RWMutex `state:"nosave"` // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) @@ -107,7 +111,7 @@ type filesystem struct { // syncableDentries contains all dentries in this filesystem for which // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. // These fields are protected by syncMu. - syncMu sync.Mutex + syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} @@ -120,6 +124,8 @@ type filesystem struct { // dentries, it comes from QID.Path from the 9P server. Synthetic dentries // have have their inodeNumber generated sequentially, with the MSB reserved to // prevent conflicts with regular dentries. +// +// +stateify savable type inodeNumber uint64 // Reserve MSB for synthetic mounts. @@ -132,6 +138,7 @@ func inoFromPath(path uint64) inodeNumber { return inodeNumber(path &^ syntheticInoMask) } +// +stateify savable type filesystemOptions struct { // "Standard" 9P options. fd int @@ -177,6 +184,8 @@ type filesystemOptions struct { // InteropMode controls the client's interaction with other remote filesystem // users. +// +// +stateify savable type InteropMode uint32 const ( @@ -195,11 +204,7 @@ const ( // and consistent with Linux's semantics (in particular, it is not always // possible for clients to set arbitrary atimes and mtimes depending on the // remote filesystem implementation, and never possible for clients to set - // arbitrary ctimes.) If a dentry containing a client-defined atime or - // mtime is evicted from cache, client timestamps will be sent to the - // remote filesystem on a best-effort basis to attempt to ensure that - // timestamps will be preserved when another dentry representing the same - // file is instantiated. + // arbitrary ctimes.) InteropModeExclusive InteropMode = iota // InteropModeWritethrough is appropriate when there are read-only users of @@ -239,6 +244,8 @@ const ( // InternalFilesystemOptions may be passed as // vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem. +// +// +stateify savable type InternalFilesystemOptions struct { // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. This is necessary for deployments in @@ -538,6 +545,8 @@ func (fs *filesystem) Release(ctx context.Context) { } // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -567,7 +576,7 @@ type dentry struct { // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file + file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -579,7 +588,7 @@ type dentry struct { cached bool dentryEntry - dirMu sync.Mutex + dirMu sync.Mutex `state:"nosave"` // If this dentry represents a directory, children contains: // @@ -611,7 +620,7 @@ type dentry struct { // To mutate: // - Lock metadataMu and use atomic operations to update because we might // have atomic readers that don't hold the lock. - metadataMu sync.Mutex + metadataMu sync.Mutex `state:"nosave"` ino inodeNumber // immutable mode uint32 // type is immutable, perms are mutable uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic @@ -642,7 +651,7 @@ type dentry struct { // other metadata fields. nlink uint32 - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` // If this dentry represents a regular file, mappings tracks mappings of // the file into memmap.MappingSpaces. mappings is protected by mapsMu. @@ -666,12 +675,12 @@ type dentry struct { // either p9.File transitions from closed (isNil() == true) to open // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. - handleMu sync.RWMutex - readFile p9file - writeFile p9file + handleMu sync.RWMutex `state:"nosave"` + readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. hostFD int32 - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` // If this dentry represents a regular file that is client-cached, cache // maps offsets into the cached file to offsets into @@ -837,7 +846,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { atomic.StoreUint32(&d.nlink, uint32(attr.NLink)) } if mask.Size { - d.updateFileSizeLocked(attr.Size) + d.updateSizeLocked(attr.Size) } } @@ -991,7 +1000,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // d.size should be kept up to date, and privatized // copy-on-write mappings of truncated pages need to be // invalidated, even if InteropModeShared is in effect. - d.updateFileSizeLocked(stat.Size) + d.updateSizeLocked(stat.Size) } } if d.fs.opts.interop == InteropModeShared { @@ -1028,8 +1037,31 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } +// doAllocate performs an allocate operation on d. Note that d.metadataMu will +// be held when allocate is called. +func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate func() error) error { + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Allocating a smaller size is a noop. + size := offset + length + if d.cachedMetadataAuthoritative() && size <= d.size { + return nil + } + + err := allocate() + if err != nil { + return err + } + d.updateSizeLocked(size) + if d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + return nil +} + // Preconditions: d.metadataMu must be locked. -func (d *dentry) updateFileSizeLocked(newSize uint64) { +func (d *dentry) updateSizeLocked(newSize uint64) { d.dataMu.Lock() oldSize := d.size atomic.StoreUint64(&d.size, newSize) @@ -1067,6 +1099,21 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + // We only support xattrs prefixed with "user." (see b/148380782). Currently, + // there is no need to expose any other xattrs through a gofer. + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + kuid := auth.KUID(atomic.LoadUint32(&d.uid)) + kgid := auth.KGID(atomic.LoadUint32(&d.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err + } + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) +} + func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid))) } @@ -1300,30 +1347,19 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.handleMu.Unlock() if !d.file.isNil() { - if !d.isDeleted() { - // Write dirty timestamps back to the remote filesystem. - atimeDirty := atomic.LoadUint32(&d.atimeDirty) != 0 - mtimeDirty := atomic.LoadUint32(&d.mtimeDirty) != 0 - if atimeDirty || mtimeDirty { - atime := atomic.LoadInt64(&d.atime) - mtime := atomic.LoadInt64(&d.mtime) - if err := d.file.setAttr(ctx, p9.SetAttrMask{ - ATime: atimeDirty, - ATimeNotSystemTime: atimeDirty, - MTime: mtimeDirty, - MTimeNotSystemTime: mtimeDirty, - }, p9.SetAttr{ - ATimeSeconds: uint64(atime / 1e9), - ATimeNanoSeconds: uint64(atime % 1e9), - MTimeSeconds: uint64(mtime / 1e9), - MTimeNanoSeconds: uint64(mtime % 1e9), - }); err != nil { - log.Warningf("gofer.dentry.destroyLocked: failed to write dirty timestamps back: %v", err) - } - } + // Note that it's possible that d.atimeDirty or d.mtimeDirty are true, + // i.e. client and server timestamps may differ (because e.g. a client + // write was serviced by the page cache, and only written back to the + // remote file later). Ideally, we'd write client timestamps back to + // the remote filesystem so that timestamps for a new dentry + // instantiated for the same file would remain coherent. Unfortunately, + // this turns out to be too expensive in many cases, so for now we + // don't do this. + if err := d.file.close(ctx); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) } - d.file.close(ctx) d.file = p9file{} + // Remove d from the set of syncable dentries. d.fs.syncMu.Lock() delete(d.fs.syncableDentries, d) @@ -1351,9 +1387,7 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -// We only support xattrs prefixed with "user." (see b/148380782). Currently, -// there is no need to expose any other xattrs through a gofer. -func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { +func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { if d.file.isNil() || !d.userXattrSupported() { return nil, nil } @@ -1363,6 +1397,7 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui } xattrs := make([]string, 0, len(xattrMap)) for x := range xattrMap { + // We only support xattrs in the user.* namespace. if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { xattrs = append(xattrs, x) } @@ -1370,51 +1405,33 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui return xattrs, nil } -func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { +func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { if d.file.isNil() { return "", syserror.ENODATA } - if err := d.checkPermissions(creds, vfs.MayRead); err != nil { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return "", syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return "", syserror.ENODATA - } return d.file.getXattr(ctx, opts.Name, opts.Size) } -func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error { +func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { if d.file.isNil() { return syserror.EPERM } - if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return syserror.EPERM - } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } -func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error { +func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error { if d.file.isNil() { return syserror.EPERM } - if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return syserror.EPERM - } return d.file.removeXattr(ctx, name) } @@ -1623,12 +1640,14 @@ func (d *dentry) decLinks() { // fileDescription is embedded by gofer implementations of // vfs.FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once + lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. } func (fd *fileDescription) filesystem() *filesystem { @@ -1666,30 +1685,30 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return nil } -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { - return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size) +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size) } -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { - return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts) +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.dentry().getXattr(ctx, auth.CredentialsFromContext(ctx), &opts) } -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { d := fd.dentry() - if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { + if err := d.setXattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { return err } d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d := fd.dentry() - if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { + if err := d.removeXattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { return err } d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 104157512..a9ebe1206 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -25,6 +25,8 @@ import ( // handle represents a remote "open file descriptor", consisting of an opened // fid (p9.File) and optionally a host file descriptor. +// +// These are explicitly not savable. type handle struct { file p9file fd int32 // -1 if unavailable diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 87f0b877f..21b4a96fe 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -127,6 +127,13 @@ func (f p9file) close(ctx context.Context) error { return err } +func (f p9file) setAttrClose(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error { + ctx.UninterruptibleSleepStart(false) + err := f.file.SetAttrClose(valid, attr) + ctx.UninterruptibleSleepFinish(false) + return err +} + func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { ctx.UninterruptibleSleepStart(false) fdobj, qid, iounit, err := f.file.Open(flags) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index a2e9342d5..eeaf6e444 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -39,11 +39,12 @@ func (d *dentry) isRegularFile() bool { return d.fileType() == linux.S_IFREG } +// +stateify savable type regularFileFD struct { fileDescription // off is the file offset. off is protected by mu. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 } @@ -79,28 +80,11 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { // Allocate implements vfs.FileDescriptionImpl.Allocate. func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - - // Allocating a smaller size is a noop. - size := offset + length - if d.cachedMetadataAuthoritative() && size <= d.size { - return nil - } - - d.handleMu.RLock() - err := d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) - d.handleMu.RUnlock() - if err != nil { - return err - } - d.dataMu.Lock() - atomic.StoreUint64(&d.size, size) - d.dataMu.Unlock() - if d.cachedMetadataAuthoritative() { - d.touchCMtimeLocked() - } - return nil + return d.doAllocate(ctx, offset, length, func() error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) } // PRead implements vfs.FileDescriptionImpl.PRead. @@ -915,6 +899,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { // dentryPlatformFile is only used when a host FD representing the remote file // is available (i.e. dentry.hostFD >= 0), and that FD is used for application // memory mappings (i.e. !filesystem.opts.forcePageCache). +// +// +stateify savable type dentryPlatformFile struct { *dentry @@ -927,7 +913,7 @@ type dentryPlatformFile struct { hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once + hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 85d2bee72..326b940a7 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -36,12 +36,14 @@ func (d *dentry) isSocket() bool { // An endpoint's lifetime is the time between when filesystem.BoundEndpointAt() // is called and either BoundEndpoint.BidirectionalConnect or // BoundEndpoint.UnidirectionalConnect is called. +// +// +stateify savable type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry // file is the p9 file that contains a single unopened fid. - file p9.File + file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // path is the sentry path where this endpoint is bound. path string diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 3c39aa9b7..71581736c 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -33,11 +34,13 @@ import ( // special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is // in effect) regular files. specialFileFD differs from regularFileFD by using // per-FD handles instead of shared per-dentry handles, and never buffering I/O. +// +// +stateify savable type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. - handle handle + handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. // isRegularFile is true if this FD represents a regular file which is only // possible when filesystemOptions.regularFilesUseSpecialFileFD is in @@ -55,7 +58,7 @@ type specialFileFD struct { queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 } @@ -135,6 +138,16 @@ func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { fd.fileDescription.EventUnregister(e) } +func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + if fd.isRegularFile { + d := fd.dentry() + return d.doAllocate(ctx, offset, length, func() error { + return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) + } + return fd.FileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if fd.seekable && offset < 0 { @@ -235,11 +248,12 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off d.touchCMtime() } buf := make([]byte, src.NumBytes()) - // Don't do partial writes if we get a partial read from src. - if _, err := src.CopyIn(ctx, buf); err != nil { - return 0, offset, err + copied, copyErr := src.CopyIn(ctx, buf) + if copied == 0 && copyErr != nil { + // Only return the error if we didn't get any data. + return 0, offset, copyErr } - n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) + n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } @@ -256,7 +270,10 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off atomic.StoreUint64(&d.size, uint64(offset)) } } - return int64(n), offset, err + if err != nil { + return int64(n), offset, err + } + return int64(n), offset, copyErr } // Write implements vfs.FileDescriptionImpl.Write. diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index be1c88c82..56bcf9bdb 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -49,6 +49,7 @@ go_library( "//pkg/fspath", "//pkg/iovec", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", @@ -72,7 +73,6 @@ go_library( "//pkg/unet", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 7561f821c..ffe4ddb32 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -58,7 +58,7 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) ( canMap: fileType == linux.S_IFREG, } i.pf.inode = i - i.refs.EnableLeakCheck() + i.EnableLeakCheck() // Non-seekable files can't be memory mapped, assert this. if !i.seekable && i.canMap { @@ -126,7 +126,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // For simplicity, fileDescription.offset is set to 0. Technically, we // should only set to 0 on files that are not seekable (sockets, pipes, // etc.), and use the offset from the host fd otherwise when importing. - return i.open(ctx, d.VFSDentry(), mnt, flags) + return i.open(ctx, d, mnt, flags) } // ImportFD sets up and returns a vfs.FileDescription from a donated fd. @@ -137,14 +137,16 @@ func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs } // filesystemType implements vfs.FilesystemType. +// +// +stateify savable type filesystemType struct{} -// GetFilesystem implements FilesystemType.GetFilesystem. +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { panic("host.filesystemType.GetFilesystem should never be called") } -// Name implements FilesystemType.Name. +// Name implements vfs.FilesystemType.Name. func (filesystemType) Name() string { return "none" } @@ -166,6 +168,8 @@ func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) { } // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -185,6 +189,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { kernfs.InodeNoStatFS kernfs.InodeNotDirectory @@ -193,7 +199,7 @@ type inode struct { locks vfs.FileLocks // When the reference count reaches zero, the host fd is closed. - refs inodeRefs + inodeRefs // hostFD contains the host fd that this file was originally created from, // which must be available at time of restore. @@ -233,7 +239,7 @@ type inode struct { canMap bool // mapsMu protects mappings. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` // If canMap is true, mappings tracks mappings of hostFD into // memmap.MappingSpaces. @@ -243,7 +249,7 @@ type inode struct { pf inodePlatformFile } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { @@ -252,7 +258,7 @@ func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, a return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid)) } -// Mode implements kernfs.Inode. +// Mode implements kernfs.Inode.Mode. func (i *inode) Mode() linux.FileMode { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { @@ -263,7 +269,7 @@ func (i *inode) Mode() linux.FileMode { return linux.FileMode(s.Mode) } -// Stat implements kernfs.Inode. +// Stat implements kernfs.Inode.Stat. func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL @@ -376,7 +382,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { }, nil } -// SetStat implements kernfs.Inode. +// SetStat implements kernfs.Inode.SetStat. func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { s := &opts.Stat @@ -435,19 +441,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre return nil } -// IncRef implements kernfs.Inode. -func (i *inode) IncRef() { - i.refs.IncRef() -} - -// TryIncRef implements kernfs.Inode. -func (i *inode) TryIncRef() bool { - return i.refs.TryIncRef() -} - -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(ctx context.Context) { - i.refs.DecRef(func() { + i.inodeRefs.DecRef(func() { if i.wouldBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } @@ -457,16 +453,16 @@ func (i *inode) DecRef(ctx context.Context) { }) } -// Open implements kernfs.Inode. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +// Open implements kernfs.Inode.Open. +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/. if i.Mode().FileType() == linux.S_IFSOCK { return nil, syserror.ENXIO } - return i.open(ctx, vfsd, rp.Mount(), opts.Flags) + return i.open(ctx, d, rp.Mount(), opts.Flags) } -func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) { +func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { return nil, err @@ -490,17 +486,17 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u return nil, err } // Currently, we only allow Unix sockets to be imported. - return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks) + return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d.VFSDentry(), &i.locks) case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR: if i.isTTY { fd := &TTYFileDescription{ fileDescription: fileDescription{inode: i}, - termios: linux.DefaultSlaveTermios, + termios: linux.DefaultReplicaTermios, } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd - if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil @@ -509,7 +505,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u fd := &fileDescription{inode: i} fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd - if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil @@ -521,6 +517,8 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u } // fileDescription is embedded by host fd implementations of FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -535,40 +533,35 @@ type fileDescription struct { inode *inode // offsetMu protects offset. - offsetMu sync.Mutex + offsetMu sync.Mutex `state:"nosave"` // offset specifies the current file offset. It is only meaningful when // inode.seekable is true. offset int64 } -// SetStat implements vfs.FileDescriptionImpl. +// SetStat implements vfs.FileDescriptionImpl.SetStat. func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts) } -// Stat implements vfs.FileDescriptionImpl. +// Stat implements vfs.FileDescriptionImpl.Stat. func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts) } -// Release implements vfs.FileDescriptionImpl. +// Release implements vfs.FileDescriptionImpl.Release. func (f *fileDescription) Release(context.Context) { // noop } -// Allocate implements vfs.FileDescriptionImpl. +// Allocate implements vfs.FileDescriptionImpl.Allocate. func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { - if !f.inode.seekable { - return syserror.ESPIPE - } - - // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds. - return syserror.EOPNOTSUPP + return unix.Fallocate(f.inode.hostFD, uint32(mode), int64(offset), int64(length)) } -// PRead implements FileDescriptionImpl. +// PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { i := f.inode if !i.seekable { @@ -578,7 +571,7 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags) } -// Read implements FileDescriptionImpl. +// Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { i := f.inode if !i.seekable { @@ -615,7 +608,7 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off return int64(n), err } -// PWrite implements FileDescriptionImpl. +// PWrite implements vfs.FileDescriptionImpl.PWrite. func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { if !f.inode.seekable { return 0, syserror.ESPIPE @@ -624,7 +617,7 @@ func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, of return f.writeToHostFD(ctx, src, offset, opts.Flags) } -// Write implements FileDescriptionImpl. +// Write implements vfs.FileDescriptionImpl.Write. func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { i := f.inode if !i.seekable { @@ -672,7 +665,7 @@ func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSeque return int64(n), err } -// Seek implements FileDescriptionImpl. +// Seek implements vfs.FileDescriptionImpl.Seek. // // Note that we do not support seeking on directories, since we do not even // allow directory fds to be imported at all. @@ -737,13 +730,13 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return f.offset, nil } -// Sync implements FileDescriptionImpl. +// Sync implements vfs.FileDescriptionImpl.Sync. func (f *fileDescription) Sync(context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } -// ConfigureMMap implements FileDescriptionImpl. +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { if !f.inode.canMap { return syserror.ENODEV diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index 65d3af38c..b51a17bed 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -27,11 +27,13 @@ import ( // cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // // inodePlatformFile should only be used if inode.canMap is true. +// +// +stateify savable type inodePlatformFile struct { *inode // fdRefsMu protects fdRefs. - fdRefsMu sync.Mutex + fdRefsMu sync.Mutex `state:"nosave"` // fdRefs counts references on memmap.File offsets. It is used solely for // memory accounting. @@ -41,7 +43,7 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once + fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go index 35ded24bc..c0bf45f08 100644 --- a/pkg/sentry/fsimpl/host/socket_unsafe.go +++ b/pkg/sentry/fsimpl/host/socket_unsafe.go @@ -63,10 +63,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) ( controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC if n > length { - return length, n, msg.Controllen, controlTrunc, err + return length, n, msg.Controllen, controlTrunc, nil } - return n, n, msg.Controllen, controlTrunc, err + return n, n, msg.Controllen, controlTrunc, nil } // fdWriteVec sends from bufs to fd. diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index 7a9be4b97..f5c596fec 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -25,11 +26,12 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // TTYFileDescription implements vfs.FileDescriptionImpl for a host file // descriptor that wraps a TTY FD. +// +// +stateify savable type TTYFileDescription struct { fileDescription @@ -76,7 +78,7 @@ func (t *TTYFileDescription) Release(ctx context.Context) { t.fileDescription.Release(ctx) } -// PRead implements vfs.FileDescriptionImpl. +// PRead implements vfs.FileDescriptionImpl.PRead. // // Reading from a TTY is only allowed for foreground process groups. Background // process groups will either get EIO or a SIGTTIN. @@ -94,7 +96,7 @@ func (t *TTYFileDescription) PRead(ctx context.Context, dst usermem.IOSequence, return t.fileDescription.PRead(ctx, dst, offset, opts) } -// Read implements vfs.FileDescriptionImpl. +// Read implements vfs.FileDescriptionImpl.Read. // // Reading from a TTY is only allowed for foreground process groups. Background // process groups will either get EIO or a SIGTTIN. @@ -112,7 +114,7 @@ func (t *TTYFileDescription) Read(ctx context.Context, dst usermem.IOSequence, o return t.fileDescription.Read(ctx, dst, opts) } -// PWrite implements vfs.FileDescriptionImpl. +// PWrite implements vfs.FileDescriptionImpl.PWrite. func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { t.mu.Lock() defer t.mu.Unlock() @@ -127,7 +129,7 @@ func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, return t.fileDescription.PWrite(ctx, src, offset, opts) } -// Write implements vfs.FileDescriptionImpl. +// Write implements vfs.FileDescriptionImpl.Write. func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { t.mu.Lock() defer t.mu.Unlock() @@ -142,7 +144,7 @@ func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, return t.fileDescription.Write(ctx, src, opts) } -// Ioctl implements vfs.FileDescriptionImpl. +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 637dca70c..5e91e0536 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -83,6 +83,7 @@ go_library( "slot_list.go", "static_directory_refs.go", "symlink.go", + "synthetic_directory.go", ], visibility = ["//pkg/sentry:internal"], deps = [ diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 1ee089620..b929118b1 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -56,9 +56,9 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint } // Open implements Inode.Open. -func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &DynamicBytesFD{} - if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil { + if err := fd.Init(rp.Mount(), d, f.data, &f.locks, opts.Flags); err != nil { return nil, err } return &fd.vfsfd, nil @@ -87,12 +87,12 @@ type DynamicBytesFD struct { } // Init initializes a DynamicBytesFD. -func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error { +func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error { fd.LockFD.Init(locks) - if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return err } - fd.inode = d.Impl().(*Dentry).inode + fd.inode = d.inode fd.SetDataSource(data) return nil } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 6518ff5cd..0a4cd4057 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -29,6 +29,8 @@ import ( ) // SeekEndConfig describes the SEEK_END behaviour for FDs. +// +// +stateify savable type SeekEndConfig int // Constants related to SEEK_END behaviour for FDs. @@ -41,6 +43,8 @@ const ( ) // GenericDirectoryFDOptions contains configuration for a GenericDirectoryFD. +// +// +stateify savable type GenericDirectoryFDOptions struct { SeekEnd SeekEndConfig } @@ -56,6 +60,8 @@ type GenericDirectoryFDOptions struct { // Must be initialize with Init before first use. // // Lock ordering: mu => children.mu. +// +// +stateify savable type GenericDirectoryFD struct { vfs.FileDescriptionDefaultImpl vfs.DirectoryFileDescriptionDefaultImpl @@ -68,7 +74,7 @@ type GenericDirectoryFD struct { children *OrderedChildren // mu protects the fields below. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // off is the current directory offset. Protected by "mu". off int64 @@ -76,12 +82,12 @@ type GenericDirectoryFD struct { // NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its // dentry. -func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) (*GenericDirectoryFD, error) { +func NewGenericDirectoryFD(m *vfs.Mount, d *Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) (*GenericDirectoryFD, error) { fd := &GenericDirectoryFD{} if err := fd.Init(children, locks, opts, fdOpts); err != nil { return nil, err } - if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(fd, opts.Flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return fd, nil @@ -195,8 +201,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // these. childIdx := fd.off - 2 for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { - inode := it.Dentry.Impl().(*Dentry).inode - stat, err := inode.Stat(ctx, fd.filesystem(), opts) + stat, err := it.Dentry.inode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 0e3011689..5cc1c4281 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -37,8 +37,7 @@ import ( // * !rp.Done(). // // Postcondition: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) { - d := vfsd.Impl().(*Dentry) +func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, mayFollowSymlinks bool) (*Dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR } @@ -55,20 +54,20 @@ afterSymlink: // calls d_revalidate(), but walk_component() => handle_dots() does not. if name == "." { rp.Advance() - return vfsd, nil + return d, nil } if name == ".." { - if isRoot, err := rp.CheckRoot(ctx, vfsd); err != nil { + if isRoot, err := rp.CheckRoot(ctx, d.VFSDentry()); err != nil { return nil, err } else if isRoot || d.parent == nil { rp.Advance() - return vfsd, nil + return d, nil } - if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + if err := rp.CheckMount(ctx, d.parent.VFSDentry()); err != nil { return nil, err } rp.Advance() - return &d.parent.vfsd, nil + return d.parent, nil } if len(name) > linux.NAME_MAX { return nil, syserror.ENAMETOOLONG @@ -79,7 +78,7 @@ afterSymlink: if err != nil { return nil, err } - if err := rp.CheckMount(ctx, &next.vfsd); err != nil { + if err := rp.CheckMount(ctx, next.VFSDentry()); err != nil { return nil, err } // Resolve any symlink at current path component. @@ -102,7 +101,7 @@ afterSymlink: goto afterSymlink } rp.Advance() - return &next.vfsd, nil + return next, nil } // revalidateChildLocked must be called after a call to parent.vfsd.Child(name) @@ -122,25 +121,21 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir if !child.inode.Valid(ctx) { delete(parent.children, name) vfsObj.InvalidateDentry(ctx, &child.vfsd) - fs.deferDecRef(&child.vfsd) // Reference from Lookup. + fs.deferDecRef(child) // Reference from Lookup. child = nil } } if child == nil { - // Dentry isn't cached; it either doesn't exist or failed - // revalidation. Attempt to resolve it via Lookup. - // - // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return - // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes - // that all dentries in the filesystem are (kernfs.)Dentry and performs - // vfs.DentryImpl casts accordingly. - childVFSD, err := parent.inode.Lookup(ctx, name) + // Dentry isn't cached; it either doesn't exist or failed revalidation. + // Attempt to resolve it via Lookup. + c, err := parent.inode.Lookup(ctx, name) if err != nil { return nil, err } - // Reference on childVFSD dropped by a corresponding Valid. - child = childVFSD.Impl().(*Dentry) - parent.insertChildLocked(name, child) + // Reference on c (provided by Lookup) will be dropped when the dentry + // fails validation. + parent.InsertChildLocked(name, c) + child = c } return child, nil } @@ -153,20 +148,19 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // Preconditions: Filesystem.mu must be locked for at least reading. // // Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { - vfsd := rp.Start() +func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*Dentry, error) { + d := rp.Start().Impl().(*Dentry) for !rp.Done() { var err error - vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */) + d, err = fs.stepExistingLocked(ctx, rp, d, true /* mayFollowSymlinks */) if err != nil { - return nil, nil, err + return nil, err } } - d := vfsd.Impl().(*Dentry) if rp.MustBeDir() && !d.isDir() { - return nil, nil, syserror.ENOTDIR + return nil, syserror.ENOTDIR } - return vfsd, d.inode, nil + return d, nil } // walkParentDirLocked resolves all but the last path component of rp to an @@ -181,20 +175,19 @@ func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingP // * !rp.Done(). // // Postconditions: Caller must call fs.processDeferredDecRefs*. -func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { - vfsd := rp.Start() +func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*Dentry, error) { + d := rp.Start().Impl().(*Dentry) for !rp.Final() { var err error - vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */) + d, err = fs.stepExistingLocked(ctx, rp, d, true /* mayFollowSymlinks */) if err != nil { - return nil, nil, err + return nil, err } } - d := vfsd.Impl().(*Dentry) if !d.isDir() { - return nil, nil, syserror.ENOTDIR + return nil, syserror.ENOTDIR } - return vfsd, d.inode, nil + return d, nil } // checkCreateLocked checks that a file named rp.Component() may be created in @@ -202,10 +195,9 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // // Preconditions: // * Filesystem.mu must be locked for at least reading. -// * parentInode == parentVFSD.Impl().(*Dentry).Inode. // * isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { +func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) { + if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return "", err } pc := rp.Component() @@ -215,11 +207,10 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v if len(pc) > linux.NAME_MAX { return "", syserror.ENAMETOOLONG } - // FIXME(gvisor.dev/issue/1193): Data race due to not holding dirMu. - if _, ok := parentVFSD.Impl().(*Dentry).children[pc]; ok { + if _, ok := parent.children[pc]; ok { return "", syserror.EEXIST } - if parentVFSD.IsDead() { + if parent.VFSDentry().IsDead() { return "", syserror.ENOENT } return pc, nil @@ -228,8 +219,8 @@ func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *v // checkDeleteLocked checks that the file represented by vfsd may be deleted. // // Preconditions: Filesystem.mu must be locked for at least reading. -func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error { - parent := vfsd.Impl().(*Dentry).parent +func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) error { + parent := d.parent if parent == nil { return syserror.EBUSY } @@ -258,11 +249,11 @@ func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) if err != nil { return err } - return inode.CheckPermissions(ctx, creds, ats) + return d.inode.CheckPermissions(ctx, creds, ats) } // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. @@ -270,20 +261,20 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op fs.mu.RLock() defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) if err != nil { return nil, err } if opts.CheckSearchable { - d := vfsd.Impl().(*Dentry) if !d.isDir() { return nil, syserror.ENOTDIR } - if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { return nil, err } } + vfsd := d.VFSDentry() vfsd.IncRef() // Ownership transferred to caller. return vfsd, nil } @@ -293,12 +284,12 @@ func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa fs.mu.RLock() defer fs.processDeferredDecRefs(ctx) defer fs.mu.RUnlock() - vfsd, _, err := fs.walkParentDirLocked(ctx, rp) + d, err := fs.walkParentDirLocked(ctx, rp) if err != nil { return nil, err } - vfsd.IncRef() // Ownership transferred to caller. - return vfsd, nil + d.IncRef() // Ownership transferred to caller. + return d.VFSDentry(), nil } // LinkAt implements vfs.FilesystemImpl.LinkAt. @@ -308,12 +299,15 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + parent, err := fs.walkParentDirLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -330,11 +324,11 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EPERM } - childVFSD, err := parentInode.NewLink(ctx, pc, d.inode) + child, err := parent.inode.NewLink(ctx, pc, d.inode) if err != nil { return err } - parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, child) return nil } @@ -345,12 +339,15 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + parent, err := fs.walkParentDirLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -358,11 +355,14 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - childVFSD, err := parentInode.NewDir(ctx, pc, opts) + child, err := parent.inode.NewDir(ctx, pc, opts) if err != nil { - return err + if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + return err + } + child = newSyntheticDirectory(rp.Credentials(), opts.Mode) } - parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, child) return nil } @@ -373,12 +373,15 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + parent, err := fs.walkParentDirLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -386,11 +389,11 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - newVFSD, err := parentInode.NewNode(ctx, pc, opts) + newD, err := parent.inode.NewNode(ctx, pc, opts) if err != nil { return err } - parentVFSD.Impl().(*Dentry).InsertChild(pc, newVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, newD) return nil } @@ -406,28 +409,27 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // Do not create new file. if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) if err != nil { fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) return nil, err } - if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) return nil, err } - inode.IncRef() - defer inode.DecRef(ctx) + d.inode.IncRef() + defer d.inode.DecRef(ctx) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) - return inode.Open(ctx, rp, vfsd, opts) + return d.inode.Open(ctx, rp, d, opts) } // May create new file. mustCreate := opts.Flags&linux.O_EXCL != 0 - vfsd := rp.Start() - inode := vfsd.Impl().(*Dentry).inode + d := rp.Start().Impl().(*Dentry) fs.mu.Lock() unlocked := false unlock := func() { @@ -444,22 +446,22 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { return nil, err } - inode.IncRef() - defer inode.DecRef(ctx) + d.inode.IncRef() + defer d.inode.DecRef(ctx) unlock() - return inode.Open(ctx, rp, vfsd, opts) + return d.inode.Open(ctx, rp, d, opts) } afterTrailingSymlink: - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + parent, err := fs.walkParentDirLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return nil, err } // Check for search permission in the parent directory. - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { + if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil { return nil, err } // Reject attempts to open directories with O_CREAT. @@ -474,10 +476,10 @@ afterTrailingSymlink: return nil, syserror.ENAMETOOLONG } // Determine whether or not we need to create a file. - childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD, false /* mayFollowSymlinks */) + child, err := fs.stepExistingLocked(ctx, rp, parent, false /* mayFollowSymlinks */) if err == syserror.ENOENT { // Already checked for searchability above; now check for writability. - if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { + if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -485,16 +487,18 @@ afterTrailingSymlink: } defer rp.Mount().EndWrite() // Create and open the child. - childVFSD, err = parentInode.NewFile(ctx, pc, opts) + child, err := parent.inode.NewFile(ctx, pc, opts) if err != nil { return nil, err } - child := childVFSD.Impl().(*Dentry) - parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + // FIXME(gvisor.dev/issue/1193): Race between checking existence with + // fs.stepExistingLocked and parent.InsertChild. If possible, we should hold + // dirMu from one to the other. + parent.InsertChild(pc, child) child.inode.IncRef() defer child.inode.DecRef(ctx) unlock() - return child.inode.Open(ctx, rp, childVFSD, opts) + return child.inode.Open(ctx, rp, child, opts) } if err != nil { return nil, err @@ -503,7 +507,6 @@ afterTrailingSymlink: if mustCreate { return nil, syserror.EEXIST } - child := childVFSD.Impl().(*Dentry) if rp.ShouldFollowSymlink() && child.isSymlink() { targetVD, targetPathname, err := child.inode.Getlink(ctx, rp.Mount()) if err != nil { @@ -530,22 +533,22 @@ afterTrailingSymlink: child.inode.IncRef() defer child.inode.DecRef(ctx) unlock() - return child.inode.Open(ctx, rp, &child.vfsd, opts) + return child.inode.Open(ctx, rp, child, opts) } // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { fs.mu.RLock() - d, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { return "", err } - if !d.Impl().(*Dentry).isSymlink() { + if !d.isSymlink() { return "", syserror.EINVAL } - return inode.Readlink(ctx) + return d.inode.Readlink(ctx, rp.Mount()) } // RenameAt implements vfs.FilesystemImpl.RenameAt. @@ -562,11 +565,10 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Resolve the destination directory first to verify that it's on this // Mount. - dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp) + dstDir, err := fs.walkParentDirLocked(ctx, rp) if err != nil { return err } - dstDir := dstDirVFSD.Impl().(*Dentry) mnt := rp.Mount() if mnt != oldParentVD.Mount() { return syserror.EXDEV @@ -584,16 +586,15 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } - srcVFSD := &src.vfsd // Can we remove the src dentry? - if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil { + if err := checkDeleteLocked(ctx, rp, src); err != nil { return err } // Can we create the dst dentry? var dst *Dentry - pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode) + pc, err := checkCreateLocked(ctx, rp, dstDir) switch err { case nil: // Ok, continue with rename as replacement. @@ -604,14 +605,14 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } dst = dstDir.children[pc] if dst == nil { - panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD)) + panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDir)) } default: return err } var dstVFSD *vfs.Dentry if dst != nil { - dstVFSD = &dst.vfsd + dstVFSD = dst.VFSDentry() } mntns := vfs.MountNamespaceFromContext(ctx) @@ -627,17 +628,18 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa defer dstDir.dirMu.Unlock() } + srcVFSD := src.VFSDentry() if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil { return err } - replaced, err := srcDir.inode.Rename(ctx, src.name, pc, srcVFSD, dstDirVFSD) + replaced, err := srcDir.inode.Rename(ctx, src.name, pc, src, dstDir) if err != nil { virtfs.AbortRenameDentry(srcVFSD, dstVFSD) return err } delete(srcDir.children, src.name) if srcDir != dstDir { - fs.deferDecRef(srcDirVFSD) + fs.deferDecRef(srcDir) dstDir.IncRef() } src.parent = dstDir @@ -646,7 +648,11 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa dstDir.children = make(map[string]*Dentry) } dstDir.children[pc] = src - virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaced) + var replaceVFSD *vfs.Dentry + if replaced != nil { + replaceVFSD = replaced.VFSDentry() + } + virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD) return nil } @@ -654,7 +660,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - vfsd, inode, err := fs.walkExistingLocked(ctx, rp) + + d, err := fs.walkExistingLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return err @@ -663,14 +670,13 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } defer rp.Mount().EndWrite() - if err := checkDeleteLocked(ctx, rp, vfsd); err != nil { + if err := checkDeleteLocked(ctx, rp, d); err != nil { return err } - d := vfsd.Impl().(*Dentry) if !d.isDir() { return syserror.ENOTDIR } - if inode.HasChildren() { + if d.inode.HasChildren() { return syserror.ENOTEMPTY } virtfs := rp.VirtualFilesystem() @@ -680,10 +686,12 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error mntns := vfs.MountNamespaceFromContext(ctx) defer mntns.DecRef(ctx) + vfsd := d.VFSDentry() if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { + + if err := parentDentry.inode.RmDir(ctx, d.name, d); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } @@ -694,7 +702,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { @@ -703,31 +711,31 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts if opts.Stat.Mask == 0 { return nil } - return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts) + return d.inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts) } // StatAt implements vfs.FilesystemImpl.StatAt. func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statx{}, err } - return inode.Stat(ctx, fs.VFSFilesystem(), opts) + return d.inode.Stat(ctx, fs.VFSFilesystem(), opts) } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statfs{}, err } - return inode.StatFS(ctx, fs.VFSFilesystem()) + return d.inode.StatFS(ctx, fs.VFSFilesystem()) } // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. @@ -737,12 +745,15 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ } fs.mu.Lock() defer fs.mu.Unlock() - parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp) + parent, err := fs.walkParentDirLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return err } - pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode) + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + + pc, err := checkCreateLocked(ctx, rp, parent) if err != nil { return err } @@ -750,11 +761,11 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ return err } defer rp.Mount().EndWrite() - childVFSD, err := parentInode.NewSymlink(ctx, pc, target) + child, err := parent.inode.NewSymlink(ctx, pc, target) if err != nil { return err } - parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) + parent.InsertChildLocked(pc, child) return nil } @@ -762,7 +773,8 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() - vfsd, _, err := fs.walkExistingLocked(ctx, rp) + + d, err := fs.walkExistingLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { return err @@ -771,10 +783,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } defer rp.Mount().EndWrite() - if err := checkDeleteLocked(ctx, rp, vfsd); err != nil { + if err := checkDeleteLocked(ctx, rp, d); err != nil { return err } - d := vfsd.Impl().(*Dentry) if d.isDir() { return syserror.EISDIR } @@ -784,10 +795,11 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error defer parentDentry.dirMu.Unlock() mntns := vfs.MountNamespaceFromContext(ctx) defer mntns.DecRef(ctx) + vfsd := d.VFSDentry() if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { + if err := parentDentry.inode.Unlink(ctx, d.name, d); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } @@ -795,25 +807,25 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return nil } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() - _, inode, err := fs.walkExistingLocked(ctx, rp) + d, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { return nil, err } - if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { + if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil { return nil, err } return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { @@ -823,10 +835,10 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, syserror.ENOTSUP } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { @@ -836,10 +848,10 @@ func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", syserror.ENOTSUP } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { @@ -849,10 +861,10 @@ func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return syserror.ENOTSUP } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *Filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index c0b863ba4..49210e748 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -31,6 +31,8 @@ import ( // count for inodes, performing no extra actions when references are obtained or // released. This is suitable for simple file inodes that don't reference any // resources. +// +// +stateify savable type InodeNoopRefCount struct { } @@ -50,30 +52,32 @@ func (InodeNoopRefCount) TryIncRef() bool { // InodeDirectoryNoNewChildren partially implements the Inode interface. // InodeDirectoryNoNewChildren represents a directory inode which does not // support creation of new children. +// +// +stateify savable type InodeDirectoryNoNewChildren struct{} // NewFile implements Inode.NewFile. -func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) { return nil, syserror.EPERM } // NewDir implements Inode.NewDir. -func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) { return nil, syserror.EPERM } // NewLink implements Inode.NewLink. -func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*Dentry, error) { return nil, syserror.EPERM } // NewSymlink implements Inode.NewSymlink. -func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*Dentry, error) { return nil, syserror.EPERM } // NewNode implements Inode.NewNode. -func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) { return nil, syserror.EPERM } @@ -81,6 +85,8 @@ func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOpt // inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not // represent directories can embed this to provide no-op implementations for // directory-related functions. +// +// +stateify savable type InodeNotDirectory struct { } @@ -90,47 +96,47 @@ func (InodeNotDirectory) HasChildren() bool { } // NewFile implements Inode.NewFile. -func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*Dentry, error) { panic("NewFile called on non-directory inode") } // NewDir implements Inode.NewDir. -func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*Dentry, error) { panic("NewDir called on non-directory inode") } // NewLink implements Inode.NewLinkink. -func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*Dentry, error) { panic("NewLink called on non-directory inode") } // NewSymlink implements Inode.NewSymlink. -func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*Dentry, error) { panic("NewSymlink called on non-directory inode") } // NewNode implements Inode.NewNode. -func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*Dentry, error) { panic("NewNode called on non-directory inode") } // Unlink implements Inode.Unlink. -func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error { +func (InodeNotDirectory) Unlink(context.Context, string, *Dentry) error { panic("Unlink called on non-directory inode") } // RmDir implements Inode.RmDir. -func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error { +func (InodeNotDirectory) RmDir(context.Context, string, *Dentry) error { panic("RmDir called on non-directory inode") } // Rename implements Inode.Rename. -func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) { +func (InodeNotDirectory) Rename(context.Context, string, string, *Dentry, *Dentry) (*Dentry, error) { panic("Rename called on non-directory inode") } // Lookup implements Inode.Lookup. -func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*Dentry, error) { panic("Lookup called on non-directory inode") } @@ -149,10 +155,12 @@ func (InodeNotDirectory) Valid(context.Context) bool { // dymanic entries (i.e. entries that are not "hashed" into the // vfs.Dentry.children) can embed this to provide no-op implementations for // functions related to dynamic entries. +// +// +stateify savable type InodeNoDynamicLookup struct{} // Lookup implements Inode.Lookup. -func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*Dentry, error) { return nil, syserror.ENOENT } @@ -169,10 +177,12 @@ func (InodeNoDynamicLookup) Valid(ctx context.Context) bool { // InodeNotSymlink partially implements the Inode interface, specifically the // inodeSymlink sub interface. All inodes that are not symlinks may embed this // to return the appropriate errors from symlink-related functions. +// +// +stateify savable type InodeNotSymlink struct{} // Readlink implements Inode.Readlink. -func (InodeNotSymlink) Readlink(context.Context) (string, error) { +func (InodeNotSymlink) Readlink(context.Context, *vfs.Mount) (string, error) { return "", syserror.EINVAL } @@ -186,6 +196,8 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // inode attributes. // // Must be initialized by Init prior to first use. +// +// +stateify savable type InodeAttrs struct { devMajor uint32 devMinor uint32 @@ -256,12 +268,29 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return a.SetInodeStat(ctx, fs, creds, opts) +} + +// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. +// This function can be used by other kernfs-based filesystem implementation to +// sets the unexported attributes into kernfs.InodeAttrs. +func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { + + // Note that not all fields are modifiable. For example, the file type and + // inode numbers are immutable after node creation. Setting the size is often + // allowed by kernfs files but does not do anything. If some other behavior is + // needed, the embedder should consider extending SetStat. + // + // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { return syserror.EPERM } + if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { + return syserror.EISDIR + } if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } @@ -284,13 +313,6 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut atomic.StoreUint32(&a.gid, stat.GID) } - // Note that not all fields are modifiable. For example, the file type and - // inode numbers are immutable after node creation. - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - // Also, STATX_SIZE will need some special handling, because read-only static - // files should return EIO for truncate operations. - return nil } @@ -320,13 +342,16 @@ func (a *InodeAttrs) DecLinks() { } } +// +stateify savable type slot struct { Name string - Dentry *vfs.Dentry + Dentry *Dentry slotEntry } // OrderedChildrenOptions contains initialization options for OrderedChildren. +// +// +stateify savable type OrderedChildrenOptions struct { // Writable indicates whether vfs.FilesystemImpl methods implemented by // OrderedChildren may modify the tracked children. This applies to @@ -342,12 +367,14 @@ type OrderedChildrenOptions struct { // directories. // // Must be initialize with Init before first use. +// +// +stateify savable type OrderedChildren struct { // Can children be modified by user syscalls? It set to false, interface // methods that would modify the children return EPERM. Immutable. writable bool - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` order slotList set map[string]*slot } @@ -380,7 +407,7 @@ func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint3 if child.isDir() { links++ } - if err := o.Insert(name, child.VFSDentry()); err != nil { + if err := o.Insert(name, child); err != nil { panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d)) } d.InsertChild(name, child) @@ -397,7 +424,7 @@ func (o *OrderedChildren) HasChildren() bool { // Insert inserts child into o. This ignores the writability of o, as this is // not part of the vfs.FilesystemImpl interface, and is a lower-level operation. -func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error { +func (o *OrderedChildren) Insert(name string, child *Dentry) error { o.mu.Lock() defer o.mu.Unlock() if _, ok := o.set[name]; ok { @@ -421,10 +448,10 @@ func (o *OrderedChildren) removeLocked(name string) { } // Precondition: caller must hold o.mu for writing. -func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry { +func (o *OrderedChildren) replaceChildLocked(name string, new *Dentry) *Dentry { if s, ok := o.set[name]; ok { // Existing slot with given name, simply replace the dentry. - var old *vfs.Dentry + var old *Dentry old, s.Dentry = s.Dentry, new return old } @@ -440,7 +467,7 @@ func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs. } // Precondition: caller must hold o.mu for reading or writing. -func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error { +func (o *OrderedChildren) checkExistingLocked(name string, child *Dentry) error { s, ok := o.set[name] if !ok { return syserror.ENOENT @@ -452,7 +479,7 @@ func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) er } // Unlink implements Inode.Unlink. -func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { +func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *Dentry) error { if !o.writable { return syserror.EPERM } @@ -468,12 +495,13 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.De } // Rmdir implements Inode.Rmdir. -func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { +func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *Dentry) error { // We're not responsible for checking that child is a directory, that it's // empty, or updating any link counts; so this is the same as unlink. return o.Unlink(ctx, name, child) } +// +stateify savable type renameAcrossDifferentImplementationsError struct{} func (renameAcrossDifferentImplementationsError) Error() string { @@ -489,8 +517,8 @@ func (renameAcrossDifferentImplementationsError) Error() string { // that will support Rename. // // Postcondition: reference on any replaced dentry transferred to caller. -func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) { - dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren) +func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (*Dentry, error) { + dst, ok := dstDir.inode.(interface{}).(*OrderedChildren) if !ok { return nil, renameAcrossDifferentImplementationsError{} } @@ -532,12 +560,14 @@ func (o *OrderedChildren) nthLocked(i int64) *slot { } // InodeSymlink partially implements Inode interface for symlinks. +// +// +stateify savable type InodeSymlink struct { InodeNotDirectory } // Open implements Inode.Open. -func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return nil, syserror.ELOOP } @@ -564,6 +594,7 @@ var _ Inode = (*StaticDirectory)(nil) func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry, fdOpts GenericDirectoryFDOptions) *Dentry { inode := &StaticDirectory{} inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) + inode.EnableLeakCheck() dentry := &Dentry{} dentry.Init(inode) @@ -584,9 +615,9 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint3 s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } -// Open implements kernfs.Inode. -func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts, s.fdOpts) +// Open implements kernfs.Inode.Open. +func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := NewGenericDirectoryFD(rp.Mount(), d, &s.OrderedChildren, &s.locks, &opts, s.fdOpts) if err != nil { return nil, err } @@ -598,21 +629,25 @@ func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credenti return syserror.EPERM } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (s *StaticDirectory) DecRef(context.Context) { s.StaticDirectoryRefs.DecRef(s.Destroy) } // AlwaysValid partially implements kernfs.inodeDynamicLookup. +// +// +stateify savable type AlwaysValid struct{} -// Valid implements kernfs.inodeDynamicLookup. +// Valid implements kernfs.inodeDynamicLookup.Valid. func (*AlwaysValid) Valid(context.Context) bool { return true } // InodeNoStatFS partially implements the Inode interface, where the client // filesystem doesn't support statfs(2). +// +// +stateify savable type InodeNoStatFS struct{} // StatFS implements Inode.StatFS. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 88fcd54aa..6d3d79333 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -29,7 +29,7 @@ // // Reference Model: // -// Kernfs dentries represents named pointers to inodes. Dentries and inode have +// Kernfs dentries represents named pointers to inodes. Dentries and inodes have // independent lifetimes and reference counts. A child dentry unconditionally // holds a reference on its parent directory's dentry. A dentry also holds a // reference on the inode it points to. Multiple dentries can point to the same @@ -60,20 +60,23 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" ) // Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory // filesystem. Concrete implementations are expected to embed this in their own // Filesystem type. +// +// +stateify savable type Filesystem struct { vfsfs vfs.Filesystem - droppedDentriesMu sync.Mutex + droppedDentriesMu sync.Mutex `state:"nosave"` // droppedDentries is a list of dentries waiting to be DecRef()ed. This is // used to defer dentry destruction until mu can be acquired for // writing. Protected by droppedDentriesMu. - droppedDentries []*vfs.Dentry + droppedDentries []*Dentry // mu synchronizes the lifetime of Dentries on this filesystem. Holding it // for reading guarantees continued existence of any resolved dentries, but @@ -96,7 +99,7 @@ type Filesystem struct { // defer fs.mu.RUnlock() // ... // fs.deferDecRef(dentry) - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. @@ -107,7 +110,7 @@ type Filesystem struct { // processDeferredDecRefs{,Locked}. See comment on Filesystem.mu. // // Precondition: d must not already be pending destruction. -func (fs *Filesystem) deferDecRef(d *vfs.Dentry) { +func (fs *Filesystem) deferDecRef(d *Dentry) { fs.droppedDentriesMu.Lock() fs.droppedDentries = append(fs.droppedDentries, d) fs.droppedDentriesMu.Unlock() @@ -159,6 +162,8 @@ const ( // to, and child dentries hold a reference on their parent. // // Must be initialized by Init prior to first use. +// +// +stateify savable type Dentry struct { DentryRefs @@ -172,7 +177,11 @@ type Dentry struct { name string // dirMu protects children and the names of child Dentries. - dirMu sync.Mutex + // + // Note that holding fs.mu for writing is not sufficient; + // revalidateChildLocked(), which is a very hot path, may modify children with + // fs.mu acquired for reading only. + dirMu sync.Mutex `state:"nosave"` children map[string]*Dentry inode Inode @@ -239,24 +248,25 @@ func (d *Dentry) Watches() *vfs.Watches { func (d *Dentry) OnZeroWatches(context.Context) {} // InsertChild inserts child into the vfs dentry cache with the given name under -// this dentry. This does not update the directory inode, so calling this on -// its own isn't sufficient to insert a child into a directory. InsertChild -// updates the link count on d if required. +// this dentry. This does not update the directory inode, so calling this on its +// own isn't sufficient to insert a child into a directory. // // Precondition: d must represent a directory inode. func (d *Dentry) InsertChild(name string, child *Dentry) { d.dirMu.Lock() - d.insertChildLocked(name, child) + d.InsertChildLocked(name, child) d.dirMu.Unlock() } -// insertChildLocked is equivalent to InsertChild, with additional +// InsertChildLocked is equivalent to InsertChild, with additional // preconditions. // -// Precondition: d.dirMu must be locked. -func (d *Dentry) insertChildLocked(name string, child *Dentry) { +// Preconditions: +// * d must represent a directory inode. +// * d.dirMu must be locked. +func (d *Dentry) InsertChildLocked(name string, child *Dentry) { if !d.isDir() { - panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) + panic(fmt.Sprintf("InsertChildLocked called on non-directory Dentry: %+v.", d)) } d.IncRef() // DecRef in child's Dentry.destroy. child.parent = d @@ -267,6 +277,36 @@ func (d *Dentry) insertChildLocked(name string, child *Dentry) { d.children[name] = child } +// RemoveChild removes child from the vfs dentry cache. This does not update the +// directory inode or modify the inode to be unlinked. So calling this on its own +// isn't sufficient to remove a child from a directory. +// +// Precondition: d must represent a directory inode. +func (d *Dentry) RemoveChild(name string, child *Dentry) error { + d.dirMu.Lock() + defer d.dirMu.Unlock() + return d.RemoveChildLocked(name, child) +} + +// RemoveChildLocked is equivalent to RemoveChild, with additional +// preconditions. +// +// Precondition: d.dirMu must be locked. +func (d *Dentry) RemoveChildLocked(name string, child *Dentry) error { + if !d.isDir() { + panic(fmt.Sprintf("RemoveChild called on non-directory Dentry: %+v.", d)) + } + c, ok := d.children[name] + if !ok { + return syserror.ENOENT + } + if c != child { + panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! Child: %+v, vfs: %+v", c, child)) + } + delete(d.children, name) + return nil +} + // Inode returns the dentry's inode. func (d *Dentry) Inode() Inode { return d.inode @@ -287,7 +327,6 @@ func (d *Dentry) Inode() Inode { // // - Checking that dentries passed to methods are of the appropriate file type. // - Checking permissions. -// - Updating link and reference counts. // // Specific responsibilities of implementations are documented below. type Inode interface { @@ -297,7 +336,8 @@ type Inode interface { inodeRefs // Methods related to node metadata. A generic implementation is provided by - // InodeAttrs. + // InodeAttrs. Note that a concrete filesystem using kernfs is responsible for + // managing link counts. inodeMetadata // Method for inodes that represent symlink. InodeNotSymlink provides a @@ -315,11 +355,11 @@ type Inode interface { // Open creates a file description for the filesystem object represented by // this inode. The returned file description should hold a reference on the - // inode for its lifetime. + // dentry for its lifetime. // // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing // the inode on which Open() is being called. - Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) + Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) // StatFS returns filesystem statistics for the client filesystem. This // corresponds to vfs.FilesystemImpl.StatFSAt. If the client filesystem @@ -369,30 +409,30 @@ type inodeDirectory interface { HasChildren() bool // NewFile creates a new regular file inode. - NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) + NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) // NewDir creates a new directory inode. - NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) + NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) // NewLink creates a new hardlink to a specified inode in this // directory. Implementations should create a new kernfs Dentry pointing to // target, and update target's link count. - NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) + NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) // NewSymlink creates a new symbolic link inode. - NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) + NewSymlink(ctx context.Context, name, target string) (*Dentry, error) // NewNode creates a new filesystem node for a mknod syscall. - NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) + NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) // Unlink removes a child dentry from this directory inode. - Unlink(ctx context.Context, name string, child *vfs.Dentry) error + Unlink(ctx context.Context, name string, child *Dentry) error // RmDir removes an empty child directory from this directory // inode. Implementations must update the parent directory's link count, // if required. Implementations are not responsible for checking that child // is a directory, checking for an empty directory. - RmDir(ctx context.Context, name string, child *vfs.Dentry) error + RmDir(ctx context.Context, name string, child *Dentry) error // Rename is called on the source directory containing an inode being // renamed. child should point to the resolved child in the source @@ -400,7 +440,7 @@ type inodeDirectory interface { // should return the replaced dentry or nil otherwise. // // Precondition: Caller must serialize concurrent calls to Rename. - Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error) + Rename(ctx context.Context, oldname, newname string, child, dstDir *Dentry) (replaced *Dentry, err error) } type inodeDynamicLookup interface { @@ -418,14 +458,14 @@ type inodeDynamicLookup interface { // // Lookup returns the child with an extra reference and the caller owns this // reference. - Lookup(ctx context.Context, name string) (*vfs.Dentry, error) + Lookup(ctx context.Context, name string) (*Dentry, error) // Valid should return true if this inode is still valid, or needs to // be resolved again by a call to Lookup. Valid(ctx context.Context) bool // IterDirents is used to iterate over dynamically created entries. It invokes - // cb on each entry in the directory represented by the FileDescription. + // cb on each entry in the directory represented by the Inode. // 'offset' is the offset for the entire IterDirents call, which may include // results from the caller (e.g. "." and ".."). 'relOffset' is the offset // inside the entries returned by this IterDirents invocation. In other words, @@ -437,7 +477,7 @@ type inodeDynamicLookup interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. - Readlink(ctx context.Context) (string, error) + Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path // resolution: diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 675587c6b..e413242dc 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -52,7 +52,7 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System { v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) + mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.MountOptions{}) if err != nil { t.Fatalf("Failed to create testfs root mount: %v", err) } @@ -121,8 +121,8 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod return &dir.dentry } -func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ +func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndStaticEntries, }) if err != nil { @@ -162,8 +162,8 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte return &dir.dentry } -func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ +func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndStaticEntries, }) if err != nil { @@ -176,38 +176,36 @@ func (d *dir) DecRef(context.Context) { d.dirRefs.DecRef(d.Destroy) } -func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { +func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*kernfs.Dentry, error) { creds := auth.CredentialsFromContext(ctx) dir := d.fs.newDir(creds, opts.Mode, nil) - dirVFSD := dir.VFSDentry() - if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil { + if err := d.OrderedChildren.Insert(name, dir); err != nil { dir.DecRef(ctx) return nil, err } d.IncLinks(1) - return dirVFSD, nil + return dir, nil } -func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { +func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*kernfs.Dentry, error) { creds := auth.CredentialsFromContext(ctx) f := d.fs.newFile(creds, "") - fVFSD := f.VFSDentry() - if err := d.OrderedChildren.Insert(name, fVFSD); err != nil { + if err := d.OrderedChildren.Insert(name, f); err != nil { f.DecRef(ctx) return nil, err } - return fVFSD, nil + return f, nil } -func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) { +func (*dir) NewLink(context.Context, string, kernfs.Inode) (*kernfs.Dentry, error) { return nil, syserror.EPERM } -func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) { +func (*dir) NewSymlink(context.Context, string, string) (*kernfs.Dentry, error) { return nil, syserror.EPERM } -func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) { +func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*kernfs.Dentry, error) { return nil, syserror.EPERM } diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 64731a3e4..58a93eaac 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -24,6 +24,8 @@ import ( // StaticSymlink provides an Inode implementation for symlinks that point to // a immutable target. +// +// +stateify savable type StaticSymlink struct { InodeAttrs InodeNoopRefCount @@ -51,8 +53,8 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } -// Readlink implements Inode. -func (s *StaticSymlink) Readlink(_ context.Context) (string, error) { +// Readlink implements Inode.Readlink. +func (s *StaticSymlink) Readlink(_ context.Context, _ *vfs.Mount) (string, error) { return s.target, nil } diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go new file mode 100644 index 000000000..ea7f073eb --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -0,0 +1,102 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// syntheticDirectory implements kernfs.Inode for a directory created by +// MkdirAt(ForSyntheticMountpoint=true). +// +// +stateify savable +type syntheticDirectory struct { + InodeAttrs + InodeNoStatFS + InodeNoopRefCount + InodeNoDynamicLookup + InodeNotSymlink + OrderedChildren + + locks vfs.FileLocks +} + +var _ Inode = (*syntheticDirectory)(nil) + +func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) *Dentry { + inode := &syntheticDirectory{} + inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + d := &Dentry{} + d.Init(inode) + return d +} + +func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) + } + dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.OrderedChildren.Init(OrderedChildrenOptions{ + Writable: true, + }) +} + +// Open implements Inode.Open. +func (dir *syntheticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := NewGenericDirectoryFD(rp.Mount(), d, &dir.OrderedChildren, &dir.locks, &opts, GenericDirectoryFDOptions{}) + if err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// NewFile implements Inode.NewFile. +func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*Dentry, error) { + return nil, syserror.EPERM +} + +// NewDir implements Inode.NewDir. +func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*Dentry, error) { + if !opts.ForSyntheticMountpoint { + return nil, syserror.EPERM + } + subdird := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + if err := dir.OrderedChildren.Insert(name, subdird); err != nil { + subdird.DecRef(ctx) + return nil, err + } + return subdird, nil +} + +// NewLink implements Inode.NewLink. +func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (*Dentry, error) { + return nil, syserror.EPERM +} + +// NewSymlink implements Inode.NewSymlink. +func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (*Dentry, error) { + return nil, syserror.EPERM +} + +// NewNode implements Inode.NewNode. +func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*Dentry, error) { + return nil, syserror.EPERM +} diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 13735eb05..73b126669 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -81,6 +82,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Start: d.parent.upperVD, Path: fspath.Parse(d.name), } + // Used during copy-up of memory-mapped regular files. + var mmapOpts *memmap.MMapOpts cleanupUndoCopyUp := func() { var err error if ftype == linux.S_IFDIR { @@ -89,7 +92,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop) } if err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)) + } + if d.upperVD.Ok() { + d.upperVD.DecRef(ctx) + d.upperVD = vfs.VirtualDentry{} } } switch ftype { @@ -132,6 +139,25 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { break } } + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if d.wrappedMappable != nil { + // We may have memory mappings of the file on the lower layer. + // Switch to mapping the file on the upper layer instead. + mmapOpts = &memmap.MMapOpts{ + Perms: usermem.ReadWrite, + MaxPerms: usermem.ReadWrite, + } + if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil { + cleanupUndoCopyUp() + return err + } + if mmapOpts.MappingIdentity != nil { + mmapOpts.MappingIdentity.DecRef(ctx) + } + // Don't actually switch Mappables until the end of copy-up; see + // below for why. + } if err := newFD.SetStat(ctx, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_UID | linux.STATX_GID, @@ -234,7 +260,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { panic(fmt.Sprintf("unexpected file type %o", ftype)) } - // TODO(gvisor.dev/issue/1199): copy up xattrs + if err := d.copyXattrsLocked(ctx); err != nil { + cleanupUndoCopyUp() + return err + } // Update the dentry's device and inode numbers (except for directories, // for which these remain overlay-assigned). @@ -246,14 +275,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Mask: linux.STATX_INO, }) if err != nil { - d.upperVD.DecRef(ctx) - d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return err } if upperStat.Mask&linux.STATX_INO == 0 { - d.upperVD.DecRef(ctx) - d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return syserror.EREMOTE } @@ -262,6 +287,135 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { atomic.StoreUint64(&d.ino, upperStat.Ino) } + if mmapOpts != nil && mmapOpts.Mappable != nil { + // Note that if mmapOpts != nil, then d.mapsMu is locked for writing + // (from the S_IFREG path above). + + // Propagate mappings of d to the new Mappable. Remember which mappings + // we added so we can remove them on failure. + upperMappable := mmapOpts.Mappable + allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange) + for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + added := make(memmap.MappingsOfRange) + for m := range seg.Value() { + if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil { + for m := range added { + upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) + } + for mr, mappings := range allAdded { + for m := range mappings { + upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable) + } + } + return err + } + added[m] = struct{}{} + } + allAdded[seg.Range()] = added + } + + // Switch to the new Mappable. We do this at the end of copy-up + // because: + // + // - We need to switch Mappables (by changing d.wrappedMappable) before + // invalidating Translations from the old Mappable (to pick up + // Translations from the new one). + // + // - We need to lock d.dataMu while changing d.wrappedMappable, but + // must invalidate Translations with d.dataMu unlocked (due to lock + // ordering). + // + // - Consequently, once we unlock d.dataMu, other threads may + // immediately observe the new (copied-up) Mappable, which we want to + // delay until copy-up is guaranteed to succeed. + d.dataMu.Lock() + lowerMappable := d.wrappedMappable + d.wrappedMappable = upperMappable + d.dataMu.Unlock() + d.lowerMappings.InvalidateAll(memmap.InvalidateOpts{}) + + // Remove mappings from the old Mappable. + for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + for m := range seg.Value() { + lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) + } + } + d.lowerMappings.RemoveAll() + } + atomic.StoreUint32(&d.copiedUp, 1) return nil } + +// copyXattrsLocked copies a subset of lower's extended attributes to upper. +// Attributes that configure an overlay in the lower are not copied up. +// +// Preconditions: d.copyMu must be locked for writing. +func (d *dentry) copyXattrsLocked(ctx context.Context) error { + vfsObj := d.fs.vfsfs.VirtualFilesystem() + lowerPop := &vfs.PathOperation{Root: d.lowerVDs[0], Start: d.lowerVDs[0]} + upperPop := &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD} + + lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0) + if err != nil { + if err == syserror.EOPNOTSUPP { + // There are no guarantees as to the contents of lowerXattrs. + return nil + } + ctx.Infof("failed to copy up xattrs because ListXattrAt failed: %v", err) + return err + } + + for _, name := range lowerXattrs { + // Do not copy up overlay attributes. + if isOverlayXattr(name) { + continue + } + + value, err := vfsObj.GetXattrAt(ctx, d.fs.creds, lowerPop, &vfs.GetXattrOptions{Name: name, Size: 0}) + if err != nil { + ctx.Infof("failed to copy up xattrs because GetXattrAt failed: %v", err) + return err + } + + if err := vfsObj.SetXattrAt(ctx, d.fs.creds, upperPop, &vfs.SetXattrOptions{Name: name, Value: value}); err != nil { + ctx.Infof("failed to copy up xattrs because SetXattrAt failed: %v", err) + return err + } + } + return nil +} + +// copyUpDescendantsLocked ensures that all descendants of d are copied up. +// +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) error { + dirents, err := d.getDirentsLocked(ctx) + if err != nil { + return err + } + for _, dirent := range dirents { + if dirent.Name == "." || dirent.Name == ".." { + continue + } + child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + if err != nil { + return err + } + if err := child.copyUpLocked(ctx); err != nil { + return err + } + if child.isDir() { + child.dirMu.Lock() + err := child.copyUpDescendantsLocked(ctx, ds) + child.dirMu.Unlock() + if err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go index b1b292e83..df4492346 100644 --- a/pkg/sentry/fsimpl/overlay/directory.go +++ b/pkg/sentry/fsimpl/overlay/directory.go @@ -100,12 +100,13 @@ func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string return whiteouts, readdirErr } +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl - mu sync.Mutex + mu sync.Mutex `state:"nosave"` off int64 dirents []vfs.Dirent } @@ -116,10 +117,12 @@ func (fd *directoryFD) Release(ctx context.Context) { // IterDirents implements vfs.FileDescriptionImpl.IterDirents. func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + d := fd.dentry() + defer d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent) + fd.mu.Lock() defer fd.mu.Unlock() - d := fd.dentry() if fd.dirents == nil { ds, err := d.getDirents(ctx) if err != nil { @@ -143,7 +146,14 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { defer d.fs.renameMu.RUnlock() d.dirMu.Lock() defer d.dirMu.Unlock() + return d.getDirentsLocked(ctx) +} +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +func (d *dentry) getDirentsLocked(ctx context.Context) ([]vfs.Dirent, error) { if d.dirents != nil { return d.dirents, nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index e720bfb0b..bd11372d5 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -15,6 +15,8 @@ package overlay import ( + "fmt" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,10 +29,15 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// _OVL_XATTR_PREFIX is an extended attribute key prefix to identify overlayfs +// attributes. +// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_PREFIX +const _OVL_XATTR_PREFIX = linux.XATTR_TRUSTED_PREFIX + "overlay." + // _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for // opaque directories. // Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE -const _OVL_XATTR_OPAQUE = linux.XATTR_TRUSTED_PREFIX + "overlay.opaque" +const _OVL_XATTR_OPAQUE = _OVL_XATTR_PREFIX + "opaque" func isWhiteout(stat *linux.Statx) bool { return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0 @@ -205,6 +212,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str lookupErr = err return false } + defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) if !existsOnAnyLayer { @@ -243,6 +251,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } // Update child to include this layer. + childVD.IncRef() if isUpper { child.upperVD = childVD child.copiedUp = 1 @@ -267,10 +276,10 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str // Directories are merged with directories from lower layers if they // are not explicitly opaque. - opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{ + opaqueVal, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{ Root: childVD, Start: childVD, - }, &vfs.GetxattrOptions{ + }, &vfs.GetXattrOptions{ Name: _OVL_XATTR_OPAQUE, Size: 1, }) @@ -490,7 +499,13 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err := create(parent, name, childLayer == lookupLayerUpperWhiteout); err != nil { return err } + parent.dirents = nil + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */) return nil } @@ -504,7 +519,7 @@ func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFil func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) { if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)) } } @@ -616,12 +631,13 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop) } return err } + old.watches.Notify(ctx, "", linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent, false /* unlinked */) return nil }) } @@ -655,7 +671,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }, }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -665,12 +681,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // There may be directories on lower layers (previously hidden by // the whiteout) that the new directory should not be merged with. // Mark it opaque to prevent merging. - if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{ + if err := vfsObj.SetXattrAt(ctx, fs.creds, &pop, &vfs.SetXattrOptions{ Name: _OVL_XATTR_OPAQUE, Value: "y", }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)) } else { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -714,7 +730,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -743,6 +759,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf start := rp.Start().Impl().(*dentry) if rp.Done() { + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } if mustCreate { return nil, syserror.EEXIST } @@ -766,6 +785,10 @@ afterTrailingSymlink: if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } + // Reject attempts to open directories with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } // Determine whether or not we need to create a file. parent.dirMu.Lock() child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) @@ -774,12 +797,11 @@ afterTrailingSymlink: parent.dirMu.Unlock() return fd, err } + parent.dirMu.Unlock() if err != nil { - parent.dirMu.Unlock() return nil, err } // Open existing child or follow symlink. - parent.dirMu.Unlock() if mustCreate { return nil, syserror.EEXIST } @@ -794,6 +816,9 @@ afterTrailingSymlink: start = parent goto afterTrailingSymlink } + if rp.MustBeDir() && !child.isDir() { + return nil, syserror.ENOTDIR + } if mayWrite { if err := child.copyUpLocked(ctx); err != nil { return nil, err @@ -925,7 +950,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -936,7 +961,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving child, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -957,6 +982,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving // just can't open it anymore for some reason. return nil, err } + parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */) return &fd.vfsfd, nil } @@ -1002,9 +1028,224 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } defer mnt.EndWrite() - // FIXME(gvisor.dev/issue/1199): Actually implement rename. - _ = newParent - return syserror.EXDEV + oldParent := oldParentVD.Dentry().Impl().(*dentry) + creds := rp.Credentials() + if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a + // directory, we need to check for write permission on it. + oldParent.dirMu.Lock() + defer oldParent.dirMu.Unlock() + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + if err != nil { + return err + } + if err := vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&oldParent.mode)), auth.KUID(atomic.LoadUint32(&renamed.uid))); err != nil { + return err + } + if renamed.isDir() { + if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { + return syserror.EINVAL + } + if oldParent != newParent { + if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + } + } else { + if opts.MustBeDir || rp.MustBeDir() { + return syserror.ENOTDIR + } + } + + if oldParent != newParent { + if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + newParent.dirMu.Lock() + defer newParent.dirMu.Unlock() + } + if newParent.vfsd.IsDead() { + return syserror.ENOENT + } + replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) + if err != nil { + return err + } + var ( + replaced *dentry + replacedVFSD *vfs.Dentry + whiteouts map[string]bool + ) + if replacedLayer.existsInOverlay() { + replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil { + return err + } + replacedVFSD = &replaced.vfsd + if replaced.isDir() { + if !renamed.isDir() { + return syserror.EISDIR + } + if genericIsAncestorDentry(replaced, renamed) { + return syserror.ENOTEMPTY + } + replaced.dirMu.Lock() + defer replaced.dirMu.Unlock() + whiteouts, err = replaced.collectWhiteoutsForRmdirLocked(ctx) + if err != nil { + return err + } + } else { + if rp.MustBeDir() || renamed.isDir() { + return syserror.ENOTDIR + } + } + } + + if oldParent == newParent && oldName == newName { + return nil + } + + // renamed and oldParent need to be copied-up before they're renamed on the + // upper layer. + if err := renamed.copyUpLocked(ctx); err != nil { + return err + } + // If renamed is a directory, all of its descendants need to be copied-up + // before they're renamed on the upper layer. + if renamed.isDir() { + if err := renamed.copyUpDescendantsLocked(ctx, &ds); err != nil { + return err + } + } + // newParent must be copied-up before it can contain renamed on the upper + // layer. + if err := newParent.copyUpLocked(ctx); err != nil { + return err + } + // If replaced exists, it doesn't need to be copied-up, but we do need to + // serialize with copy-up. Holding renameMu for writing should be + // sufficient, but out of an abundance of caution... + if replaced != nil { + replaced.copyMu.RLock() + defer replaced.copyMu.RUnlock() + } + + vfsObj := rp.VirtualFilesystem() + mntns := vfs.MountNamespaceFromContext(ctx) + defer mntns.DecRef(ctx) + if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil { + return err + } + + newpop := vfs.PathOperation{ + Root: newParent.upperVD, + Start: newParent.upperVD, + Path: fspath.Parse(newName), + } + + needRecreateWhiteouts := false + cleanupRecreateWhiteouts := func() { + if !needRecreateWhiteouts { + return + } + for whiteoutName, whiteoutUpper := range whiteouts { + if !whiteoutUpper { + continue + } + if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{ + Root: replaced.upperVD, + Start: replaced.upperVD, + Path: fspath.Parse(whiteoutName), + }); err != nil && err != syserror.EEXIST { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err)) + } + } + } + if renamed.isDir() { + if replacedLayer == lookupLayerUpper { + // Remove whiteouts from the directory being replaced. + needRecreateWhiteouts = true + for whiteoutName, whiteoutUpper := range whiteouts { + if !whiteoutUpper { + continue + } + if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{ + Root: replaced.upperVD, + Start: replaced.upperVD, + Path: fspath.Parse(whiteoutName), + }); err != nil { + cleanupRecreateWhiteouts() + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + } + } else if replacedLayer == lookupLayerUpperWhiteout { + // We need to explicitly remove the whiteout since otherwise rename + // on the upper layer will fail with ENOTDIR. + if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil { + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + } + } + + // Essentially no gVisor filesystem supports RENAME_WHITEOUT, so just do a + // regular rename and create the whiteout at the origin manually. Unlike + // RENAME_WHITEOUT, this isn't atomic with respect to other users of the + // upper filesystem, but this is already the case for virtually all other + // overlay filesystem operations too. + oldpop := vfs.PathOperation{ + Root: oldParent.upperVD, + Start: oldParent.upperVD, + Path: fspath.Parse(oldName), + } + if err := vfsObj.RenameAt(ctx, creds, &oldpop, &newpop, &opts); err != nil { + cleanupRecreateWhiteouts() + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + + // Below this point, the renamed dentry is now at newpop, and anything we + // replaced is gone forever. Commit the rename, update the overlay + // filesystem tree, and abandon attempts to recover from errors. + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) + delete(oldParent.children, oldName) + if replaced != nil { + ds = appendDentry(ds, replaced) + } + if oldParent != newParent { + newParent.dirents = nil + // This can't drop the last reference on oldParent because one is held + // by oldParentVD, so lock recursion is impossible. + oldParent.DecRef(ctx) + ds = appendDentry(ds, oldParent) + newParent.IncRef() + renamed.parent = newParent + } + renamed.name = newName + if newParent.children == nil { + newParent.children = make(map[string]*dentry) + } + newParent.children[newName] = renamed + oldParent.dirents = nil + + if err := fs.createWhiteout(ctx, vfsObj, &oldpop); err != nil { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout at origin after RenameAt: %v", err)) + } + if renamed.isDir() { + if err := vfsObj.SetXattrAt(ctx, fs.creds, &newpop, &vfs.SetXattrOptions{ + Name: _OVL_XATTR_OPAQUE, + Value: "y", + }); err != nil { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to make renamed directory opaque: %v", err)) + } + } + + vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir()) + return nil } // RmdirAt implements vfs.FilesystemImpl.RmdirAt. @@ -1083,7 +1324,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error Start: child.upperVD, Path: fspath.Parse(whiteoutName), }); err != nil && err != syserror.EEXIST { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)) } } } @@ -1113,15 +1354,14 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Don't attempt to recover from this: the original directory is // already gone, so any dentries representing it are invalid, and // creating a new directory won't undo that. - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err) - vfsObj.AbortDeleteDentry(&child.vfsd) - return err + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)) } vfsObj.CommitDeleteDentry(ctx, &child.vfsd) delete(parent.children, name) ds = appendDentry(ds, child) parent.dirents = nil + parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0 /* cookie */, vfs.InodeEvent, true /* unlinked */) return nil } @@ -1129,12 +1369,25 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) return err } + err = d.setStatLocked(ctx, rp, opts) + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + if err != nil { + return err + } + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ctx, ev, 0 /* cookie */, vfs.InodeEvent) + } + return nil +} +// Precondition: d.fs.renameMu must be held for reading. +func (d *dentry) setStatLocked(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { mode := linux.FileMode(atomic.LoadUint32(&d.mode)) if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err @@ -1229,7 +1482,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -1322,70 +1575,175 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } } if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err) - if child != nil { - vfsObj.AbortDeleteDentry(&child.vfsd) - } - return err + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)) } + var cw *vfs.Watches if child != nil { vfsObj.CommitDeleteDentry(ctx, &child.vfsd) delete(parent.children, name) ds = appendDentry(ds, child) + cw = &child.watches } + vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name) parent.dirents = nil return nil } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// isOverlayXattr returns whether the given extended attribute configures the +// overlay. +func isOverlayXattr(name string) bool { + return strings.HasPrefix(name, _OVL_XATTR_PREFIX) +} + +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err } - // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr, - // but not any other xattr syscalls. For now we just reject all of them. - return nil, syserror.ENOTSUP + + return fs.listXattr(ctx, d, size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +func (fs *filesystem) listXattr(ctx context.Context, d *dentry, size uint64) ([]string, error) { + vfsObj := d.fs.vfsfs.VirtualFilesystem() + top := d.topLayer() + names, err := vfsObj.ListXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, size) + if err != nil { + return nil, err + } + + // Filter out all overlay attributes. + n := 0 + for _, name := range names { + if !isOverlayXattr(name) { + names[n] = name + n++ + } + } + return names[:n], err +} + +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err } - return "", syserror.ENOTSUP + + return fs.getXattr(ctx, d, rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +func (fs *filesystem) getXattr(ctx context.Context, d *dentry, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { + return "", err + } + + // Return EOPNOTSUPP when fetching an overlay attribute. + // See fs/overlayfs/super.c:ovl_own_xattr_get(). + if isOverlayXattr(opts.Name) { + return "", syserror.EOPNOTSUPP + } + + // Analogous to fs/overlayfs/super.c:ovl_other_xattr_get(). + vfsObj := d.fs.vfsfs.VirtualFilesystem() + top := d.topLayer() + return vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, opts) +} + +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + return err + } + + err = fs.setXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), &opts) + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) if err != nil { return err } - return syserror.ENOTSUP + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent) + return nil +} + +// Precondition: fs.renameMu must be locked. +func (fs *filesystem) setXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { + return err + } + + // Return EOPNOTSUPP when setting an overlay attribute. + // See fs/overlayfs/super.c:ovl_own_xattr_set(). + if isOverlayXattr(opts.Name) { + return syserror.EOPNOTSUPP + } + + // Analogous to fs/overlayfs/super.c:ovl_other_xattr_set(). + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := d.copyUpLocked(ctx); err != nil { + return err + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + return vfsObj.SetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, opts) } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) + if err != nil { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + return err + } + + err = fs.removeXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), name) + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) if err != nil { return err } - return syserror.ENOTSUP + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0 /* cookie */, vfs.InodeEvent) + return nil +} + +// Precondition: fs.renameMu must be locked. +func (fs *filesystem) removeXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, name string) error { + if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { + return err + } + + // Like SetXattrAt, return EOPNOTSUPP when removing an overlay attribute. + // Linux passes the remove request to xattr_handler->set. + // See fs/xattr.c:vfs_removexattr(). + if isOverlayXattr(name) { + return syserror.EOPNOTSUPP + } + + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := d.copyUpLocked(ctx); err != nil { + return err + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + return vfsObj.RemoveXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index 268b32537..853aee951 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -38,6 +39,7 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { }) } +// +stateify savable type nonDirectoryFD struct { fileDescription @@ -46,7 +48,7 @@ type nonDirectoryFD struct { // fileDescription.dentry().upperVD. cachedFlags is the last known value of // cachedFD.StatusFlags(). copiedUp, cachedFD, and cachedFlags are // protected by mu. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` copiedUp bool cachedFD *vfs.FileDescription cachedFlags uint32 @@ -146,6 +148,16 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux return stat, nil } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + return err + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Allocate(ctx, mode, offset, length) +} + // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() @@ -172,6 +184,9 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) return err } d.updateAfterSetStatLocked(&opts) + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) + } return nil } @@ -256,10 +271,105 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error { // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - wrappedFD, err := fd.getCurrentFD(ctx) + if err := fd.ensureMappable(ctx, opts); err != nil { + return err + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd.dentry(), opts) +} + +// ensureMappable ensures that fd.dentry().wrappedMappable is not nil. +func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { + d := fd.dentry() + + // Fast path if we already have a Mappable for the current top layer. + if atomic.LoadUint32(&d.isMappable) != 0 { + return nil + } + + // Only permit mmap of regular files, since other file types may have + // unpredictable behavior when mmapped (e.g. /dev/zero). + if atomic.LoadUint32(&d.mode)&linux.S_IFMT != linux.S_IFREG { + return syserror.ENODEV + } + + // Get a Mappable for the current top layer. + fd.mu.Lock() + defer fd.mu.Unlock() + d.copyMu.RLock() + defer d.copyMu.RUnlock() + if atomic.LoadUint32(&d.isMappable) != 0 { + return nil + } + wrappedFD, err := fd.currentFDLocked(ctx) if err != nil { return err } - defer wrappedFD.DecRef(ctx) - return wrappedFD.ConfigureMMap(ctx, opts) + if err := wrappedFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + // Use this Mappable for all mappings of this layer (unless we raced with + // another call to ensureMappable). + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + d.dataMu.Lock() + defer d.dataMu.Unlock() + if d.wrappedMappable == nil { + d.wrappedMappable = opts.Mappable + atomic.StoreUint32(&d.isMappable, 1) + } + return nil +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil { + return err + } + if !d.isCopiedUp() { + d.lowerMappings.AddMapping(ms, ar, offset, writable) + } + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable) + if !d.isCopiedUp() { + d.lowerMappings.RemoveMapping(ms, ar, offset, writable) + } +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil { + return err + } + if !d.isCopiedUp() { + d.lowerMappings.AddMapping(ms, dstAR, offset, writable) + } + return nil +} + +// Translate implements memmap.Mappable.Translate. +func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + d.dataMu.RLock() + defer d.dataMu.RUnlock() + return d.wrappedMappable.Translate(ctx, required, optional, at) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (d *dentry) InvalidateUnsavable(ctx context.Context) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + return d.wrappedMappable.InvalidateUnsavable(ctx) } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 00562667f..dfbccd05f 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,10 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// *** "memmap.Mappable locks" below this point +// dentry.mapsMu +// *** "memmap.Mappable locks taken by Translate" below this point +// dentry.dataMu // // Locking dentry.dirMu in multiple dentries requires that parent dentries are // locked before child dentries, and that filesystem.renameMu is locked to @@ -37,6 +41,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -46,6 +51,8 @@ import ( const Name = "overlay" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // Name implements vfs.FilesystemType.Name. @@ -55,6 +62,8 @@ func (FilesystemType) Name() string { // FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to // FilesystemType.GetFilesystem. +// +// +stateify savable type FilesystemOptions struct { // Callers passing FilesystemOptions to // overlay.FilesystemType.GetFilesystem() are responsible for ensuring that @@ -71,6 +80,8 @@ type FilesystemOptions struct { } // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -93,7 +104,7 @@ type filesystem struct { // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different // dentries. - renameMu sync.RWMutex + renameMu sync.RWMutex `state:"nosave"` // lastDirIno is the last inode number assigned to a directory. lastDirIno // is accessed using atomic memory operations. @@ -106,16 +117,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsoptsRaw := opts.InternalData fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) if fsoptsRaw != nil && !haveFSOpts { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) + ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } if haveFSOpts { if len(fsopts.LowerRoots) == 0 { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") return nil, nil, syserror.EINVAL } if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") return nil, nil, syserror.EINVAL } // We don't enforce a maximum number of lower layers when not @@ -132,7 +143,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt delete(mopts, "workdir") upperPath := fspath.Parse(upperPathname) if !upperPath.Absolute { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -144,13 +155,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt CheckSearchable: true, }) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) return nil, nil, err } defer upperRoot.DecRef(ctx) privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) return nil, nil, err } defer privateUpperRoot.DecRef(ctx) @@ -158,24 +169,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } lowerPathnamesStr, ok := mopts["lowerdir"] if !ok { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") return nil, nil, syserror.EINVAL } if len(lowerPathnames) > maxLowerLayers { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) return nil, nil, syserror.EINVAL } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) + ctx.Infof("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) return nil, nil, syserror.EINVAL } lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -187,13 +198,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt CheckSearchable: true, }) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err } defer privateLowerRoot.DecRef(ctx) @@ -201,7 +212,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } if len(mopts) != 0 { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) + ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } @@ -274,7 +285,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, syserror.EREMOTE } if isWhiteout(&rootStat) { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") + ctx.Infof("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") root.destroyLocked(ctx) fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL @@ -362,6 +373,8 @@ func (fs *filesystem) newDirIno() uint64 { } // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -394,7 +407,7 @@ type dentry struct { // and dirents (if not nil) is a cache of dirents as returned by // directoryFDs representing this directory. children is protected by // dirMu. - dirMu sync.Mutex + dirMu sync.Mutex `state:"nosave"` children map[string]*dentry dirents []vfs.Dirent @@ -404,7 +417,7 @@ type dentry struct { // If !upperVD.Ok(), it can transition to a valid vfs.VirtualDentry (i.e. // be copied up) with copyMu locked for writing; otherwise, it is // immutable. lowerVDs is always immutable. - copyMu sync.RWMutex + copyMu sync.RWMutex `state:"nosave"` upperVD vfs.VirtualDentry lowerVDs []vfs.VirtualDentry @@ -419,7 +432,43 @@ type dentry struct { devMinor uint32 ino uint64 + // If this dentry represents a regular file, then: + // + // - mapsMu is used to synchronize between copy-up and memmap.Mappable + // methods on dentry preceding mm.MemoryManager.activeMu in the lock order. + // + // - dataMu is used to synchronize between copy-up and + // dentry.(memmap.Mappable).Translate. + // + // - lowerMappings tracks memory mappings of the file. lowerMappings is + // used to invalidate mappings of the lower layer when the file is copied + // up to ensure that they remain coherent with subsequent writes to the + // file. (Note that, as of this writing, Linux overlayfs does not do this; + // this feature is a gVisor extension.) lowerMappings is protected by + // mapsMu. + // + // - If this dentry is copied-up, then wrappedMappable is the Mappable + // obtained from a call to the current top layer's + // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil + // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil. + // wrappedMappable is protected by mapsMu and dataMu. + // + // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is + // accessed using atomic memory operations. + mapsMu sync.Mutex + lowerMappings memmap.MappingSet + dataMu sync.RWMutex + wrappedMappable memmap.Mappable + isMappable uint32 + locks vfs.FileLocks + + // watches is the set of inotify watches on the file repesented by this dentry. + // + // Note that hard links to the same file will not share the same set of + // watches, due to the fact that we do not have inode structures in this + // overlay implementation. + watches vfs.Watches } // newDentry creates a new dentry. The dentry initially has no references; it @@ -479,6 +528,14 @@ func (d *dentry) checkDropLocked(ctx context.Context) { if atomic.LoadInt64(&d.refs) != 0 { return } + + // Make sure that we do not lose watches on dentries that have not been + // deleted. Note that overlayfs never calls VFS.InvalidateDentry(), so + // d.vfsd.IsDead() indicates that d was deleted. + if !d.vfsd.IsDead() && d.watches.Size() > 0 { + return + } + // Refs is still zero; destroy it. d.destroyLocked(ctx) return @@ -507,6 +564,8 @@ func (d *dentry) destroyLocked(ctx context.Context) { lowerVD.DecRef(ctx) } + d.watches.HandleDeletion(ctx) + if d.parent != nil { d.parent.dirMu.Lock() if !d.vfsd.IsDead() { @@ -525,19 +584,36 @@ func (d *dentry) destroyLocked(ctx context.Context) { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. func (d *dentry) InotifyWithParent(ctx context.Context, events uint32, cookie uint32, et vfs.EventType) { - // TODO(gvisor.dev/issue/1479): Implement inotify. + if d.isDir() { + events |= linux.IN_ISDIR + } + + // overlayfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates + // that d was deleted. + deleted := d.vfsd.IsDead() + + d.fs.renameMu.RLock() + // The ordering below is important, Linux always notifies the parent first. + if d.parent != nil { + d.parent.watches.Notify(ctx, d.name, events, cookie, et, deleted) + } + d.watches.Notify(ctx, "", events, cookie, et, deleted) + d.fs.renameMu.RUnlock() } // Watches implements vfs.DentryImpl.Watches. func (d *dentry) Watches() *vfs.Watches { - // TODO(gvisor.dev/issue/1479): Implement inotify. - return nil + return &d.watches } // OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *dentry) OnZeroWatches(context.Context) {} +func (d *dentry) OnZeroWatches(ctx context.Context) { + if atomic.LoadInt64(&d.refs) == 0 { + d.fs.renameMu.Lock() + d.checkDropLocked(ctx) + d.fs.renameMu.Unlock() + } +} // iterLayers invokes yield on each layer comprising d, from top to bottom. If // any call to yield returns false, iterLayer stops iteration. @@ -570,6 +646,16 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + kuid := auth.KUID(atomic.LoadUint32(&d.uid)) + kgid := auth.KGID(atomic.LoadUint32(&d.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err + } + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) +} + // statInternalMask is the set of stat fields that is set by // dentry.statInternalTo(). const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -608,6 +694,8 @@ func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) { // fileDescription is embedded by overlay implementations of // vfs.FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -622,6 +710,48 @@ func (fd *fileDescription) dentry() *dentry { return fd.vfsfd.Dentry().Impl().(*dentry) } +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.filesystem().listXattr(ctx, fd.dentry(), size) +} + +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.filesystem().getXattr(ctx, fd.dentry(), auth.CredentialsFromContext(ctx), &opts) +} + +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { + fs := fd.filesystem() + d := fd.dentry() + + fs.renameMu.RLock() + err := fs.setXattrLocked(ctx, d, fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), &opts) + fs.renameMu.RUnlock() + if err != nil { + return err + } + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil +} + +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { + fs := fd.filesystem() + d := fd.dentry() + + fs.renameMu.RLock() + err := fs.removeXattrLocked(ctx, d, fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), name) + fs.renameMu.RUnlock() + if err != nil { + return err + } + + d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 7053ad6db..4e2da4810 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type filesystemType struct{} // Name implements vfs.FilesystemType.Name. @@ -43,6 +44,7 @@ func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile panic("pipefs.filesystemType.GetFilesystem should never be called") } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -76,6 +78,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { kernfs.InodeNotDirectory kernfs.InodeNotSymlink @@ -144,8 +148,8 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth. } // Open implements kernfs.Inode.Open. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks) +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + return i.pipe.Open(ctx, rp.Mount(), d.VFSDentry(), opts.Flags, &i.locks) } // StatFS implements kernfs.Inode.StatFS. diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index a45b44440..2e086e34c 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -100,6 +100,7 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", + "//pkg/tcpip/network/ipv4", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 03b5941b9..05d7948ea 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -41,6 +41,7 @@ func (FilesystemType) Name() string { return Name } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -84,6 +85,8 @@ func (fs *filesystem) Release(ctx context.Context) { // dynamicInode is an overfitted interface for common Inodes with // dynamicByteSource types used in procfs. +// +// +stateify savable type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource @@ -99,6 +102,7 @@ func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux. return d } +// +stateify savable type staticFile struct { kernfs.DynamicBytesFile vfs.StaticData @@ -118,10 +122,13 @@ func newStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 // InternalData contains internal data passed in to the procfs mount via // vfs.GetFilesystemOptions.InternalData. +// +// +stateify savable type InternalData struct { Cgroups map[string]string } +// +stateify savable type implStatFS struct{} // StatFS implements kernfs.Inode.StatFS. diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index d57d94dbc..47ecd941c 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -68,8 +68,8 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, return dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *subtasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { tid, err := strconv.ParseUint(name, 10, 32) if err != nil { return nil, syserror.ENOENT @@ -82,12 +82,10 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, e if subTask.ThreadGroup() != i.task.ThreadGroup() { return nil, syserror.ENOENT } - - subTaskDentry := i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers) - return subTaskDentry.VFSDentry(), nil + return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers), nil } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { @@ -118,6 +116,7 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb return offset, nil } +// +stateify savable type subtasksFD struct { kernfs.GenericDirectoryFD @@ -155,21 +154,21 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro return fd.GenericDirectoryFD.SetStat(ctx, opts) } -// Open implements kernfs.Inode. -func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +// Open implements kernfs.Inode.Open. +func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &subtasksFD{task: i.task} if err := fd.Init(&i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }); err != nil { return nil, err } - if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// Stat implements kernfs.Inode. +// Stat implements kernfs.Inode.Stat. func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { @@ -181,12 +180,12 @@ func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs return stat, nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *subtasksInode) DecRef(context.Context) { i.subtasksInodeRefs.DecRef(i.Destroy) } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index dbdb5d929..a7cd6f57e 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -53,6 +53,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}), "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), "comm": fs.newComm(task, fs.NextIno(), 0444), + "cwd": fs.newCwdSymlink(task, fs.NextIno()), "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), "exe": fs.newExeSymlink(task, fs.NextIno()), "fd": fs.newFDDirInode(task), @@ -106,9 +107,9 @@ func (i *taskInode) Valid(ctx context.Context) bool { return i.task.ExitState() != kernel.TaskExitDead } -// Open implements kernfs.Inode. -func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ +// Open implements kernfs.Inode.Open. +func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) if err != nil { @@ -117,18 +118,20 @@ func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D return fd.VFSFileDescription(), nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *taskInode) DecRef(context.Context) { i.taskInodeRefs.DecRef(i.Destroy) } // taskOwnedInode implements kernfs.Inode and overrides inode owner with task // effective user and group. +// +// +stateify savable type taskOwnedInode struct { kernfs.Inode @@ -168,7 +171,7 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. return d } -// Stat implements kernfs.Inode. +// Stat implements kernfs.Inode.Stat. func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { stat, err := i.Inode.Stat(ctx, fs, opts) if err != nil { @@ -186,7 +189,7 @@ func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs. return stat, nil } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { mode := i.Mode() uid, gid := i.getOwner(mode) diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 3f0d78461..0866cea2b 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -51,6 +51,7 @@ func taskFDExists(ctx context.Context, t *kernel.Task, fd int32) bool { return true } +// +stateify savable type fdDir struct { locks vfs.FileLocks @@ -62,7 +63,7 @@ type fdDir struct { produceSymlink bool } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { @@ -86,14 +87,19 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off Name: strconv.FormatUint(uint64(fd), 10), Type: typ, Ino: i.fs.NextIno(), - NextOff: offset + 1, + NextOff: int64(fd) + 3, } if err := cb.Handle(dirent); err != nil { - return offset, err + // Getdents should iterate correctly despite mutation + // of fds, so we return the next fd to serialize plus + // 2 (which accounts for the "." and ".." tracked by + // kernfs) as the offset. + return int64(fd) + 2, err } - offset++ } - return offset, nil + // We serialized them all. Next offset should be higher than last + // serialized fd. + return int64(fds[len(fds)-1]) + 3, nil } // fdDirInode represents the inode for /proc/[pid]/fd directory. @@ -130,8 +136,8 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry { return dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *fdDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { return nil, syserror.ENOENT @@ -140,13 +146,12 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } - taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()) - return taskDentry.VFSDentry(), nil + return i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()), nil } -// Open implements kernfs.Inode. -func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ +// Open implements kernfs.Inode.Open. +func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) if err != nil { @@ -155,7 +160,7 @@ func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs. return fd.VFSFileDescription(), nil } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. // // This is to match Linux, which uses a special permission handler to guarantee // that a process can still access /proc/self/fd after it has executed @@ -177,7 +182,7 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia return err } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *fdDirInode) DecRef(context.Context) { i.fdDirInodeRefs.DecRef(i.Destroy) } @@ -209,7 +214,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *ker return d } -func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { +func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { file, _ := getTaskFD(s.task, s.fd) if file == nil { return "", syserror.ENOENT @@ -264,8 +269,8 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry { return dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { return nil, syserror.ENOENT @@ -278,13 +283,12 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, task: i.task, fd: fd, } - dentry := i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data) - return dentry.VFSDentry(), nil + return i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data), nil } -// Open implements kernfs.Inode. -func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ +// Open implements kernfs.Inode.Open. +func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) if err != nil { @@ -293,7 +297,7 @@ func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd * return fd.VFSFileDescription(), nil } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *fdInfoDirInode) DecRef(context.Context) { i.fdInfoDirInodeRefs.DecRef(i.Destroy) } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 356036b9b..3fbf081a6 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -543,7 +543,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { var vss, rss, data uint64 s.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() @@ -667,20 +667,24 @@ func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentr return d } -// Readlink implements kernfs.Inode. -func (s *exeSymlink) Readlink(ctx context.Context) (string, error) { - if !kernel.ContextCanTrace(ctx, s.task, false) { - return "", syserror.EACCES - } - - // Pull out the executable for /proc/[pid]/exe. - exec, err := s.executable() +// Readlink implements kernfs.Inode.Readlink. +func (s *exeSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { + exec, _, err := s.Getlink(ctx, nil) if err != nil { return "", err } defer exec.DecRef(ctx) - return exec.PathnameWithDeleted(ctx), nil + root := vfs.RootFromContext(ctx) + if !root.Ok() { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + vfsObj := exec.Mount().Filesystem().VirtualFilesystem() + name, _ := vfsObj.PathnameWithDeleted(ctx, root, exec) + return name, nil } // Getlink implements kernfs.Inode.Getlink. @@ -688,23 +692,12 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent if !kernel.ContextCanTrace(ctx, s.task, false) { return vfs.VirtualDentry{}, "", syserror.EACCES } - - exec, err := s.executable() - if err != nil { - return vfs.VirtualDentry{}, "", err - } - defer exec.DecRef(ctx) - - vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() - vd.IncRef() - return vd, "", nil -} - -func (s *exeSymlink) executable() (file fsbridge.File, err error) { if err := checkTaskState(s.task); err != nil { - return nil, err + return vfs.VirtualDentry{}, "", err } + var err error + var exec fsbridge.File s.task.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { @@ -715,12 +708,78 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) { // The MemoryManager may be destroyed, in which case // MemoryManager.destroy will simply set the executable to nil // (with locks held). - file = mm.Executable() - if file == nil { + exec = mm.Executable() + if exec == nil { err = syserror.ESRCH } }) - return + if err != nil { + return vfs.VirtualDentry{}, "", err + } + defer exec.DecRef(ctx) + + vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() + vd.IncRef() + return vd, "", nil +} + +// cwdSymlink is an symlink for the /proc/[pid]/cwd file. +// +// +stateify savable +type cwdSymlink struct { + implStatFS + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeSymlink + + task *kernel.Task +} + +var _ kernfs.Inode = (*cwdSymlink)(nil) + +func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry { + inode := &cwdSymlink{task: task} + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +// Readlink implements kernfs.Inode.Readlink. +func (s *cwdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { + cwd, _, err := s.Getlink(ctx, nil) + if err != nil { + return "", err + } + defer cwd.DecRef(ctx) + + root := vfs.RootFromContext(ctx) + if !root.Ok() { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + vfsObj := cwd.Mount().Filesystem().VirtualFilesystem() + name, _ := vfsObj.PathnameWithDeleted(ctx, root, cwd) + return name, nil +} + +// Getlink implements kernfs.Inode.Getlink. +func (s *cwdSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { + if !kernel.ContextCanTrace(ctx, s.task, false) { + return vfs.VirtualDentry{}, "", syserror.EACCES + } + if err := checkTaskState(s.task); err != nil { + return vfs.VirtualDentry{}, "", err + } + cwd := s.task.FSContext().WorkingDirectoryVFS2() + if !cwd.Ok() { + // It could have raced with process deletion. + return vfs.VirtualDentry{}, "", syserror.ESRCH + } + return cwd, "", nil } // mountInfoData is used to implement /proc/[pid]/mountinfo. @@ -785,6 +844,7 @@ func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } +// +stateify savable type namespaceSymlink struct { kernfs.StaticSymlink @@ -807,15 +867,15 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri return d } -// Readlink implements Inode. -func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) { +// Readlink implements kernfs.Inode.Readlink. +func (s *namespaceSymlink) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { if err := checkTaskState(s.task); err != nil { return "", err } - return s.StaticSymlink.Readlink(ctx) + return s.StaticSymlink.Readlink(ctx, mnt) } -// Getlink implements Inode.Getlink. +// Getlink implements kernfs.Inode.Getlink. func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { if err := checkTaskState(s.task); err != nil { return vfs.VirtualDentry{}, "", err @@ -832,6 +892,8 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // namespaceInode is a synthetic inode created to represent a namespace in // /proc/[pid]/ns/*. +// +// +stateify savable type namespaceInode struct { implStatFS kernfs.InodeAttrs @@ -852,12 +914,12 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32 i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } -// Open implements Inode.Open. -func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +// Open implements kernfs.Inode.Open. +func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &namespaceFD{inode: i} i.IncRef() fd.LockFD.Init(&i.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return &fd.vfsfd, nil @@ -865,6 +927,8 @@ func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd * // namespace FD is a synthetic file that represents a namespace in // /proc/[pid]/ns/*. +// +// +stateify savable type namespaceFD struct { vfs.FileDescriptionDefaultImpl vfs.LockFD @@ -875,20 +939,20 @@ type namespaceFD struct { var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil) -// Stat implements FileDescriptionImpl. +// Stat implements vfs.FileDescriptionImpl.Stat. func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() return fd.inode.Stat(ctx, vfs, opts) } -// SetStat implements FileDescriptionImpl. +// SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() creds := auth.CredentialsFromContext(ctx) return fd.inode.SetStat(ctx, vfs, creds, opts) } -// Release implements FileDescriptionImpl. +// Release implements vfs.FileDescriptionImpl.Release. func (fd *namespaceFD) Release(ctx context.Context) { fd.inode.DecRef(ctx) } diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 4e69782c7..e7f748655 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -616,6 +616,7 @@ type netSnmpData struct { var _ dynamicInode = (*netSnmpData)(nil) +// +stateify savable type snmpLine struct { prefix string header string @@ -660,7 +661,7 @@ func sprintSlice(s []uint64) string { return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. } -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { types := []interface{}{ &inet.StatSNMPIP{}, @@ -709,7 +710,7 @@ type netRouteData struct { var _ dynamicInode = (*netRouteData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. // See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT") @@ -773,7 +774,7 @@ type netStatData struct { var _ dynamicInode = (*netStatData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. // See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " + diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 3ea00ab87..d8f5dd509 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -52,8 +52,8 @@ type tasksInode struct { // '/proc/self' and '/proc/thread-self' have custom directory offsets in // Linux. So handle them outside of OrderedChildren. - selfSymlink *vfs.Dentry - threadSelfSymlink *vfs.Dentry + selfSymlink *kernfs.Dentry + threadSelfSymlink *kernfs.Dentry // cgroupControllers is a map of controller name to directory in the // cgroup hierarchy. These controllers are immutable and will be listed @@ -81,8 +81,8 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace inode := &tasksInode{ pidns: pidns, fs: fs, - selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(), - threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(), + selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns), + threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns), cgroupControllers: cgroupControllers, } inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) @@ -98,8 +98,8 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace return inode, dentry } -// Lookup implements kernfs.inodeDynamicLookup. -func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { +// Lookup implements kernfs.inodeDynamicLookup.Lookup. +func (i *tasksInode) Lookup(ctx context.Context, name string) (*kernfs.Dentry, error) { // Try to lookup a corresponding task. tid, err := strconv.ParseUint(name, 10, 64) if err != nil { @@ -118,11 +118,10 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return nil, syserror.ENOENT } - taskDentry := i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers) - return taskDentry.VFSDentry(), nil + return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers), nil } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 @@ -200,9 +199,9 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback return maxTaskID, nil } -// Open implements kernfs.Inode. -func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ +// Open implements kernfs.Inode.Open. +func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), d, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) if err != nil { @@ -229,7 +228,7 @@ func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.St return stat, nil } -// DecRef implements kernfs.Inode. +// DecRef implements kernfs.Inode.DecRef. func (i *tasksInode) DecRef(context.Context) { i.tasksInodeRefs.DecRef(i.Destroy) } @@ -237,6 +236,8 @@ func (i *tasksInode) DecRef(context.Context) { // staticFileSetStat implements a special static file that allows inode // attributes to be set. This is to support /proc files that are readonly, but // allow attributes to be set. +// +// +stateify savable type staticFileSetStat struct { dynamicBytesFileSetAttr vfs.StaticData diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 8c41729e4..f268c59b0 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type selfSymlink struct { implStatFS kernfs.InodeAttrs @@ -51,7 +52,7 @@ func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns return d } -func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { +func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? @@ -64,16 +65,17 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { return strconv.FormatUint(uint64(tgid), 10), nil } -func (s *selfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { - target, err := s.Readlink(ctx) +func (s *selfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx, mnt) return vfs.VirtualDentry{}, target, err } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } +// +stateify savable type threadSelfSymlink struct { implStatFS kernfs.InodeAttrs @@ -94,7 +96,7 @@ func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, return d } -func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { +func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? @@ -108,12 +110,12 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { return fmt.Sprintf("%d/task/%d", tgid, tid), nil } -func (s *threadSelfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { - target, err := s.Readlink(ctx) +func (s *threadSelfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx, mnt) return vfs.VirtualDentry{}, target, err } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } @@ -121,16 +123,20 @@ func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Creden // dynamicBytesFileSetAttr implements a special file that allows inode // attributes to be set. This is to support /proc files that are readonly, but // allow attributes to be set. +// +// +stateify savable type dynamicBytesFileSetAttr struct { kernfs.DynamicBytesFile } -// SetStat implements Inode.SetStat. +// SetStat implements kernfs.Inode.SetStat. func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts) } // cpuStats contains the breakdown of CPU time for /proc/stat. +// +// +stateify savable type cpuStats struct { // user is time spent in userspace tasks with non-positive niceness. user uint64 diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 038a194c7..3312b0418 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -27,9 +27,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type tcpMemDir int const ( @@ -67,6 +69,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke "tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}), "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), "tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -174,7 +177,7 @@ type tcpSackData struct { var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { if d.enabled == nil { sack, err := d.stack.TCPSACKEnabled() @@ -232,7 +235,7 @@ type tcpRecoveryData struct { var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error { recovery, err := d.stack.TCPRecovery() if err != nil { @@ -284,7 +287,7 @@ type tcpMemData struct { var _ vfs.WritableDynamicBytesSource = (*tcpMemData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error { d.mu.Lock() defer d.mu.Unlock() @@ -354,3 +357,63 @@ func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error { panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir)) } } + +// ipForwarding implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/ip_forwarding. +// +// +stateify savable +type ipForwarding struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` + enabled *bool +} + +var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error { + if ipf.enabled == nil { + enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber) + ipf.enabled = &enabled + } + + val := "0\n" + if *ipf.enabled { + // Technically, this is not quite compatible with Linux. Linux stores these + // as an integer, so if you write "2" into tcp_sack, you should get 2 back. + // Tough luck. + val = "1\n" + } + buf.WriteString(val) + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + if ipf.enabled == nil { + ipf.enabled = new(bool) + } + *ipf.enabled = v != 0 + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + return 0, err + } + return n, nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index be54897bb..6cee22823 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -20,8 +20,10 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/usermem" ) func newIPv6TestStack() *inet.TestStack { @@ -76,3 +78,72 @@ func TestIfinet6(t *testing.T) { t.Errorf("Got n.contents() = %v, want = %v", got, want) } } + +// TestIPForwarding tests the implementation of +// /proc/sys/net/ipv4/ip_forwarding +func TestConfigureIPForwarding(t *testing.T) { + ctx := context.Background() + s := inet.NewTestStack() + + var cases = []struct { + comment string + initial bool + str string + final bool + }{ + { + comment: `Forwarding is disabled; write 1 and enable forwarding`, + initial: false, + str: "1", + final: true, + }, + { + comment: `Forwarding is disabled; write 0 and disable forwarding`, + initial: false, + str: "0", + final: false, + }, + { + comment: `Forwarding is enabled; write 1 and enable forwarding`, + initial: true, + str: "1", + final: true, + }, + { + comment: `Forwarding is enabled; write 0 and disable forwarding`, + initial: true, + str: "0", + final: false, + }, + { + comment: `Forwarding is disabled; write 2404 and enable forwarding`, + initial: false, + str: "2404", + final: true, + }, + { + comment: `Forwarding is enabled; write 2404 and enable forwarding`, + initial: true, + str: "2404", + final: true, + }, + } + for _, c := range cases { + t.Run(c.comment, func(t *testing.T) { + s.IPForwarding = c.initial + + file := &ipForwarding{stack: s, enabled: &c.initial} + + // Write the values. + src := usermem.BytesIOSequence([]byte(c.str)) + if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil { + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) + } + + // Read the values from the stack and check them. + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) + } + }) + } +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index d82b3d2f3..6975af5a7 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -67,6 +67,7 @@ var ( taskStaticFiles = map[string]testutil.DirentType{ "auxv": linux.DT_REG, "cgroup": linux.DT_REG, + "cwd": linux.DT_LNK, "cmdline": linux.DT_REG, "comm": linux.DT_REG, "environ": linux.DT_REG, @@ -104,7 +105,7 @@ func setup(t *testing.T) *testutil.System { AllowUserMount: true, }) - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{}) + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.MountOptions{}) if err != nil { t.Fatalf("NewMountNamespace(): %v", err) } diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 6297e1df4..bf11b425a 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -26,7 +26,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SignalFileDescription implements FileDescriptionImpl for signal fds. +// SignalFileDescription implements vfs.FileDescriptionImpl for signal fds. +// +// +stateify savable type SignalFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -43,7 +45,7 @@ type SignalFileDescription struct { target *kernel.Task // mu protects mask. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // mask is the signal mask. Protected by mu. mask linux.SignalSet @@ -83,7 +85,7 @@ func (sfd *SignalFileDescription) SetMask(mask linux.SignalSet) { sfd.mask = mask } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { // Attempt to dequeue relevant signals. info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0) @@ -132,5 +134,5 @@ func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) { sfd.target.SignalUnregister(entry) } -// Release implements FileDescriptionImpl.Release() +// Release implements vfs.FileDescriptionImpl.Release. func (sfd *SignalFileDescription) Release(context.Context) {} diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index 94a998568..29e5371d6 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -28,14 +28,16 @@ import ( ) // filesystemType implements vfs.FilesystemType. +// +// +stateify savable type filesystemType struct{} -// GetFilesystem implements FilesystemType.GetFilesystem. +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { panic("sockfs.filesystemType.GetFilesystem should never be called") } -// Name implements FilesystemType.Name. +// Name implements vfs.FilesystemType.Name. // // Note that registering sockfs is unnecessary, except for the fact that it // will not show up under /proc/filesystems as a result. This is a very minor @@ -44,6 +46,7 @@ func (filesystemType) Name() string { return "sockfs" } +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -80,6 +83,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } // inode implements kernfs.Inode. +// +// +stateify savable type inode struct { kernfs.InodeAttrs kernfs.InodeNoopRefCount @@ -88,7 +93,7 @@ type inode struct { } // Open implements kernfs.Inode.Open. -func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return nil, syserror.ENXIO } diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index 73f3d3309..b75d70ae6 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -36,6 +36,8 @@ func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) } // kcovInode implements kernfs.Inode. +// +// +stateify savable type kcovInode struct { kernfs.InodeAttrs kernfs.InodeNoopRefCount @@ -44,7 +46,7 @@ type kcovInode struct { implStatFS } -func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { k := kernel.KernelFromContext(ctx) if k == nil { panic("KernelFromContext returned nil") @@ -54,7 +56,7 @@ func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D kcov: k.NewKcov(), } - if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{ + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, }); err != nil { @@ -63,6 +65,7 @@ func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D return &fd.vfsfd, nil } +// +stateify savable type kcovFD struct { vfs.FileDescriptionDefaultImpl vfs.NoLockFD diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 39952d2d0..1568c581f 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -34,9 +34,13 @@ const Name = "sysfs" const defaultSysDirMode = linux.FileMode(0755) // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { kernfs.Filesystem @@ -117,6 +121,8 @@ func (fs *filesystem) Release(ctx context.Context) { } // dir implements kernfs.Inode. +// +// +stateify savable type dir struct { dirRefs kernfs.InodeAttrs @@ -148,8 +154,8 @@ func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.Set } // Open implements kernfs.Inode.Open. -func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ +func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndStaticEntries, }) if err != nil { @@ -169,6 +175,8 @@ func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, err } // cpuFile implements kernfs.Inode. +// +// +stateify savable type cpuFile struct { implStatFS kernfs.DynamicBytesFile @@ -190,6 +198,7 @@ func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode li return d } +// +stateify savable type implStatFS struct{} // StatFS implements kernfs.Inode.StatFS. diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 9fd38b295..0a0d914cc 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -38,7 +38,7 @@ func newTestSystem(t *testing.T) *testutil.System { AllowUserMount: true, }) - mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{}) + mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.MountOptions{}) if err != nil { t.Fatalf("Failed to create new mount namespace: %v", err) } diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 86beaa0a8..8853c8ad2 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -26,8 +26,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// TimerFileDescription implements FileDescriptionImpl for timer fds. It also +// TimerFileDescription implements vfs.FileDescriptionImpl for timer fds. It also // implements ktime.TimerListener. +// +// +stateify savable type TimerFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -62,7 +64,7 @@ func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, return &tfd.vfsfd, nil } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { const sizeofUint64 = 8 if dst.NumBytes() < sizeofUint64 { @@ -128,7 +130,7 @@ func (tfd *TimerFileDescription) ResumeTimer() { tfd.timer.Resume() } -// Release implements FileDescriptionImpl.Release() +// Release implements vfs.FileDescriptionImpl.Release. func (tfd *TimerFileDescription) Release(context.Context) { tfd.timer.Destroy() } diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go index e5a4218e8..5209a17af 100644 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go @@ -182,7 +182,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } @@ -376,7 +376,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go index ac54d420d..9129d35b7 100644 --- a/pkg/sentry/fsimpl/tmpfs/device_file.go +++ b/pkg/sentry/fsimpl/tmpfs/device_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" ) +// +stateify savable type deviceFile struct { inode inode kind vfs.DeviceKind diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 070c75e68..e90669cf0 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// +stateify savable type directory struct { // Since directories can't be hard-linked, each directory can only be // associated with a single dentry, which we can store in the directory @@ -44,7 +45,7 @@ type directory struct { // (with inode == nil) that represent the iteration position of // directoryFDs. childList is used to support directoryFD.IterDirents() // efficiently. childList is protected by iterMu. - iterMu sync.Mutex + iterMu sync.Mutex `state:"nosave"` childList dentryList } @@ -86,6 +87,7 @@ func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error { return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid))) } +// +stateify savable type directoryFD struct { fileDescription vfs.DirectoryFileDescriptionDefaultImpl diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e0de04e05..e39cd305b 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -673,11 +673,11 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts fs.mu.RUnlock() return err } - if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil { - fs.mu.RUnlock() + err = d.inode.setStat(ctx, rp.Credentials(), &opts) + fs.mu.RUnlock() + if err != nil { return err } - fs.mu.RUnlock() if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent) @@ -770,7 +770,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return nil } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() defer fs.mu.RUnlock() @@ -792,59 +792,59 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() defer fs.mu.RUnlock() d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } - return d.inode.listxattr(size) + return d.inode.listXattr(size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() d, err := resolveLocked(ctx, rp) if err != nil { return "", err } - return d.inode.getxattr(rp.Credentials(), &opts) + return d.inode.getXattr(rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { fs.mu.RLock() d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { - fs.mu.RUnlock() + err = d.inode.setXattr(rp.Credentials(), &opts) + fs.mu.RUnlock() + if err != nil { return err } - fs.mu.RUnlock() d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.removexattr(rp.Credentials(), name); err != nil { - fs.mu.RUnlock() + err = d.inode.removeXattr(rp.Credentials(), name) + fs.mu.RUnlock() + if err != nil { return err } - fs.mu.RUnlock() d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil @@ -865,8 +865,16 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } if d.parent == nil { if d.name != "" { - // This must be an anonymous memfd file. + // This file must have been created by + // newUnlinkedRegularFileDescription(). In Linux, + // mm/shmem.c:__shmem_file_setup() => + // fs/file_table.c:alloc_file_pseudo() sets the created + // dentry's dentry_operations to anon_ops, for which d_dname == + // simple_dname. fs/d_path.c:simple_dname() defines the + // dentry's pathname to be its name, prefixed with "/" and + // suffixed with " (deleted)". b.PrependComponent("/" + d.name) + b.AppendString(" (deleted)") return vfs.PrependPathSyntheticError{} } return vfs.PrependPathAtNonMountRootError{} diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 5b0471ff4..d772db9e9 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// +stateify savable type namedPipe struct { inode inode diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index ec2701d8b..be29a2363 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -158,7 +158,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 0710b65db..a199eb33d 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -36,12 +36,18 @@ import ( ) // regularFile is a regular (=S_IFREG) tmpfs file. +// +// +stateify savable type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. memFile *pgalloc.MemoryFile + // memoryUsageKind is the memory accounting category under which pages backing + // this regularFile's contents are accounted. + memoryUsageKind usage.MemoryKind + // mapsMu protects mappings. mapsMu sync.Mutex `state:"nosave"` @@ -62,7 +68,7 @@ type regularFile struct { writableMappingPages uint64 // dataMu protects the fields below. - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` // data maps offsets into the file to offsets into memFile that store // the file's data. @@ -86,14 +92,75 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, - seals: linux.F_SEAL_SEAL, + memFile: fs.memFile, + memoryUsageKind: usage.Tmpfs, + seals: linux.F_SEAL_SEAL, } file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode) file.inode.nlink = 1 // from parent directory return &file.inode } +// newUnlinkedRegularFileDescription creates a regular file on the tmpfs +// filesystem represented by mount and returns an FD representing that file. +// The new file is not reachable by path traversal from any other file. +// +// newUnlinkedRegularFileDescription is analogous to Linux's +// mm/shmem.c:__shmem_file_setup(). +// +// Preconditions: mount must be a tmpfs mount. +func newUnlinkedRegularFileDescription(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, name string) (*regularFileFD, error) { + fs, ok := mount.Filesystem().Impl().(*filesystem) + if !ok { + panic("tmpfs.newUnlinkedRegularFileDescription() called with non-tmpfs mount") + } + + inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) + d := fs.newDentry(inode) + defer d.DecRef(ctx) + d.name = name + + fd := ®ularFileFD{} + fd.Init(&inode.locks) + flags := uint32(linux.O_RDWR) + if err := fd.vfsfd.Init(fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return fd, nil +} + +// NewZeroFile creates a new regular file and file description as for +// mmap(MAP_SHARED | MAP_ANONYMOUS). The file has the given size and is +// initially (implicitly) filled with zeroes. +// +// Preconditions: mount must be a tmpfs mount. +func NewZeroFile(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, size uint64) (*vfs.FileDescription, error) { + // Compare mm/shmem.c:shmem_zero_setup(). + fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, "dev/zero") + if err != nil { + return nil, err + } + rf := fd.inode().impl.(*regularFile) + rf.memoryUsageKind = usage.Anonymous + rf.size = size + return &fd.vfsfd, err +} + +// NewMemfd creates a new regular file and file description as for +// memfd_create. +// +// Preconditions: mount must be a tmpfs mount. +func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) { + fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, name) + if err != nil { + return nil, err + } + if allowSeals { + fd.inode().impl.(*regularFile).seals = 0 + } + return &fd.vfsfd, nil +} + // truncate grows or shrinks the file to the given size. It returns true if the // file size was updated. func (rf *regularFile) truncate(newSize uint64) (bool, error) { @@ -226,7 +293,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap. optional.End = pgend } - cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { + cerr := rf.data.Fill(ctx, required, optional, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { // Newly-allocated pages are zeroed, so we don't need to do anything. return dsts.NumBytes(), nil }) @@ -260,13 +327,14 @@ func (*regularFile) InvalidateUnsavable(context.Context) error { return nil } +// +stateify savable type regularFileFD struct { fileDescription // off is the file offset. off is accessed using atomic memory operations. // offMu serializes operations that may mutate off. off int64 - offMu sync.Mutex + offMu sync.Mutex `state:"nosave"` } // Release implements vfs.FileDescriptionImpl.Release. @@ -575,7 +643,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, case gap.Ok(): // Allocate memory for the write. gapMR := gap.Range().Intersect(pgMR) - fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs) + fr, err := rw.file.memFile.Allocate(gapMR.Length(), rw.file.memoryUsageKind) if err != nil { retErr = err goto exitLoop diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go index 3ed650474..5699d5975 100644 --- a/pkg/sentry/fsimpl/tmpfs/socket_file.go +++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go @@ -21,6 +21,8 @@ import ( ) // socketFile is a socket (=S_IFSOCK) tmpfs file. +// +// +stateify savable type socketFile struct { inode inode ep transport.BoundEndpoint diff --git a/pkg/sentry/fsimpl/tmpfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go index b0de5fabe..a102a2ee2 100644 --- a/pkg/sentry/fsimpl/tmpfs/symlink.go +++ b/pkg/sentry/fsimpl/tmpfs/symlink.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) +// +stateify savable type symlink struct { inode inode target string // immutable diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index c4cec4130..cefec8fde 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -51,9 +51,13 @@ import ( const Name = "tmpfs" // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -67,7 +71,7 @@ type filesystem struct { devMinor uint32 // mu serializes changes to the Dentry tree. - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` nextInoMinusOne uint64 // accessed using atomic memory operations } @@ -78,6 +82,8 @@ func (FilesystemType) Name() string { } // FilesystemOpts is used to pass configuration data to tmpfs. +// +// +stateify savable type FilesystemOpts struct { // RootFileType is the FileType of the filesystem root. Valid values // are: S_IFDIR, S_IFREG, and S_IFLNK. Defaults to S_IFDIR. @@ -221,6 +227,8 @@ var globalStatfs = linux.Statfs{ } // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -300,6 +308,8 @@ func (d *dentry) Watches() *vfs.Watches { func (d *dentry) OnZeroWatches(context.Context) {} // inode represents a filesystem object. +// +// +stateify savable type inode struct { // fs is the owning filesystem. fs is immutable. fs *filesystem @@ -316,12 +326,12 @@ type inode struct { // Inode metadata. Writing multiple fields atomically requires holding // mu, othewise atomic operations can be used. - mu sync.Mutex - mode uint32 // file type and mode - nlink uint32 // protected by filesystem.mu instead of inode.mu - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - ino uint64 // immutable + mu sync.Mutex `state:"nosave"` + mode uint32 // file type and mode + nlink uint32 // protected by filesystem.mu instead of inode.mu + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + ino uint64 // immutable // Linux's tmpfs has no concept of btime. atime int64 // nanoseconds @@ -626,74 +636,50 @@ func (i *inode) touchCMtimeLocked() { atomic.StoreInt64(&i.ctime, now) } -func (i *inode) listxattr(size uint64) ([]string, error) { - return i.xattrs.Listxattr(size) +func (i *inode) listXattr(size uint64) ([]string, error) { + return i.xattrs.ListXattr(size) } -func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { +func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } - return i.xattrs.Getxattr(opts) + return i.xattrs.GetXattr(opts) } -func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error { +func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error { if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } - return i.xattrs.Setxattr(opts) + return i.xattrs.SetXattr(opts) } -func (i *inode) removexattr(creds *auth.Credentials, name string) error { +func (i *inode) removeXattr(creds *auth.Credentials, name string) error { if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } - return i.xattrs.Removexattr(name) + return i.xattrs.RemoveXattr(name) } func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { - switch { - case ats&vfs.MayRead == vfs.MayRead: - if err := i.checkPermissions(creds, vfs.MayRead); err != nil { - return err - } - case ats&vfs.MayWrite == vfs.MayWrite: - if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { - return err - } - default: - panic(fmt.Sprintf("checkXattrPermissions called with impossible AccessTypes: %v", ats)) + // We currently only support extended attributes in the user.* and + // trusted.* namespaces. See b/148380782. + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { + return syserror.EOPNOTSUPP } - - switch { - case strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX): - // The trusted.* namespace can only be accessed by privileged - // users. - if creds.HasCapability(linux.CAP_SYS_ADMIN) { - return nil - } - if ats&vfs.MayWrite == vfs.MayWrite { - return syserror.EPERM - } - return syserror.ENODATA - case strings.HasPrefix(name, linux.XATTR_USER_PREFIX): - // Extended attributes in the user.* namespace are only - // supported for regular files and directories. - filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode) - if filetype == linux.S_IFREG || filetype == linux.S_IFDIR { - return nil - } - if ats&vfs.MayWrite == vfs.MayWrite { - return syserror.EPERM - } - return syserror.ENODATA - + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err } - return syserror.EOPNOTSUPP + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) } // fileDescription is embedded by tmpfs implementations of // vfs.FileDescriptionImpl. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -738,20 +724,20 @@ func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { return globalStatfs, nil } -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { - return fd.inode().listxattr(size) +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.inode().listXattr(size) } -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { - return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts) +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.inode().getXattr(auth.CredentialsFromContext(ctx), &opts) } -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { d := fd.dentry() - if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil { + if err := d.inode.setXattr(auth.CredentialsFromContext(ctx), &opts); err != nil { return err } @@ -760,10 +746,10 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption return nil } -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d := fd.dentry() - if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil { + if err := d.inode.removeXattr(auth.CredentialsFromContext(ctx), name); err != nil { return err } @@ -772,37 +758,6 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { return nil } -// NewMemfd creates a new tmpfs regular file and file description that can back -// an anonymous fd created by memfd_create. -func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) { - fs, ok := mount.Filesystem().Impl().(*filesystem) - if !ok { - panic("NewMemfd() called with non-tmpfs mount") - } - - // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with - // S_IRWXUGO. - inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) - rf := inode.impl.(*regularFile) - if allowSeals { - rf.seals = 0 - } - - d := fs.newDentry(inode) - defer d.DecRef(ctx) - d.name = name - - // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with - // FMODE_READ | FMODE_WRITE. - var fd regularFileFD - fd.Init(&inode.locks) - flags := uint32(linux.O_RDWR) - if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go index 6f3e3ae6f..99c8e3c0f 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go @@ -41,7 +41,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 326c4ed90..0ca750281 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") licenses(["notice"]) @@ -13,12 +13,35 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/marshal/primitive", "//pkg/merkletree", + "//pkg/sentry/arch", "//pkg/sentry/fs/lock", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", + ], +) + +go_test( + name = "verity_test", + srcs = [ + "verity_test.go", + ], + library = ":verity", + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/contexttest", + "//pkg/sentry/vfs", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 0e17dbddc..a560b0797 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "strconv" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -179,23 +180,19 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // corresponding Merkle tree file. // This is the offset of the root hash for child in its parent's Merkle // tree file. - off, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{ + off, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{ Root: child.lowerMerkleVD, Start: child.lowerMerkleVD, - }, &vfs.GetxattrOptions{ + }, &vfs.GetXattrOptions{ Name: merkleOffsetInParentXattr, - // Offset is a 32 bit integer. - Size: sizeOfInt32, + Size: sizeOfStringInt32, }) // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - if noCrashOnVerificationFailure { - return nil, err - } - panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -204,10 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - if noCrashOnVerificationFailure { - return nil, syserror.EINVAL - } - panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's root hash. @@ -221,10 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - if noCrashOnVerificationFailure { - return nil, err - } - panic(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -233,19 +224,16 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // dataSize is the size of raw data for the Merkle tree. For a file, // dataSize is the size of the whole file. For a directory, dataSize is // the size of all its children's root hashes. - dataSize, err := parentMerkleFD.Getxattr(ctx, &vfs.GetxattrOptions{ + dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ Name: merkleSizeXattr, - Size: sizeOfInt32, + Size: sizeOfStringInt32, }) // The Merkle tree file for the child should have been created and // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - if noCrashOnVerificationFailure { - return nil, err - } - panic(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -255,10 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - if noCrashOnVerificationFailure { - return nil, syserror.EINVAL - } - panic(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -270,11 +255,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contain the root hash of the children in the parent Merkle tree when // Verify returns with success. var buf bytes.Buffer - if err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { - if noCrashOnVerificationFailure { - return nil, syserror.EIO - } - panic(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { + return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child root hash when it's verified the first time. @@ -370,10 +352,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // corresponding Merkle tree is found. This indicates an // unexpected modification to the file system that // removed/renamed the child. - if noCrashOnVerificationFailure { - return nil, childErr - } - panic(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) } else if childErr == nil && childMerkleErr == syserror.ENOENT { // If in allowRuntimeEnable mode, and the Merkle tree file is // not created yet, we create an empty Merkle tree file, so that @@ -392,6 +371,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, Path: fspath.Parse(childMerkleFilename), }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, }) if err != nil { return nil, err @@ -409,11 +389,16 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // If runtime enable is not allowed. This indicates an // unexpected modification to the file system that // removed/renamed the Merkle tree file. - if noCrashOnVerificationFailure { - return nil, childMerkleErr - } - panic(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) } + } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { + // Both the child and the corresponding Merkle tree are missing. + // This could be an unexpected modification or due to incorrect + // parameter. + // TODO(b/167752508): Investigate possible ways to differentiate + // cases that both files are deleted from cases that they never + // exist in the file system. + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) } mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) @@ -572,8 +557,183 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - //TODO(b/159261227): Implement OpenAt. - return nil, nil + // Verity fs is read-only. + if opts.Flags&(linux.O_WRONLY|linux.O_CREAT) != 0 { + return nil, syserror.EROFS + } + + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + + start := rp.Start().Impl().(*dentry) + if rp.Done() { + return start.openLocked(ctx, rp, &opts) + } + +afterTrailingSymlink: + parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) + if err != nil { + return nil, err + } + + // Check for search permission in the parent directory. + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + + // Open existing child or follow symlink. + parent.dirMu.Lock() + child, err := fs.stepLocked(ctx, rp, parent, false /*mayFollowSymlinks*/, &ds) + parent.dirMu.Unlock() + if err != nil { + return nil, err + } + if child.isSymlink() && rp.ShouldFollowSymlink() { + target, err := child.readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + start = parent + goto afterTrailingSymlink + } + return child.openLocked(ctx, rp, &opts) +} + +// Preconditions: fs.renameMu must be locked. +func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { + // Users should not open the Merkle tree files. Those are for verity fs + // use only. + if strings.Contains(d.name, merklePrefix) { + return nil, syserror.EPERM + } + ats := vfs.AccessTypesForOpenFlags(opts) + if err := d.checkPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + + // Verity fs is read-only. + if ats&vfs.MayWrite != 0 { + return nil, syserror.EROFS + } + + // Get the path to the target file. This is only used to provide path + // information in failure case. + path, err := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD) + if err != nil { + return nil, err + } + + // Open the file in the underlying file system. + lowerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }, opts) + + // The file should exist, as we succeeded in finding its dentry. If it's + // missing, it indicates an unexpected modification to the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + } + return nil, err + } + + // lowerFD needs to be cleaned up if any error occurs. IncRef will be + // called if a verity FD is successfully created. + defer lowerFD.DecRef(ctx) + + // Open the Merkle tree file corresponding to the current file/directory + // to be used later for verifying Read/Walk. + merkleReader, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + + // The Merkle tree file should exist, as we succeeded in finding its + // dentry. If it's missing, it indicates an unexpected modification to + // the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + + // merkleReader needs to be cleaned up if any error occurs. IncRef will + // be called if a verity FD is successfully created. + defer merkleReader.DecRef(ctx) + + lowerFlags := lowerFD.StatusFlags() + lowerFDOpts := lowerFD.Options() + var merkleWriter *vfs.FileDescription + var parentMerkleWriter *vfs.FileDescription + + // Only open the Merkle tree files for write if in allowRuntimeEnable + // mode. + if d.fs.allowRuntimeEnable { + merkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + // merkleWriter is cleaned up if any error occurs. IncRef will + // be called if a verity FD is created successfully. + defer merkleWriter.DecRef(ctx) + + if d.parent != nil { + parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.parent.lowerMerkleVD, + Start: d.parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + } + return nil, err + } + // parentMerkleWriter is cleaned up if any error occurs. IncRef + // will be called if a verity FD is created successfully. + defer parentMerkleWriter.DecRef(ctx) + } + } + + fd := &fileDescription{ + d: d, + lowerFD: lowerFD, + merkleReader: merkleReader, + merkleWriter: merkleWriter, + parentMerkleWriter: parentMerkleWriter, + isDir: d.isDir(), + } + + if err := fd.vfsfd.Init(fd, lowerFlags, rp.Mount(), &d.vfsd, &lowerFDOpts); err != nil { + return nil, err + } + lowerFD.IncRef() + merkleReader.IncRef() + if merkleWriter != nil { + merkleWriter.IncRef() + } + if parentMerkleWriter != nil { + parentMerkleWriter.IncRef() + } + return &fd.vfsfd, err } // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. @@ -649,7 +809,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() @@ -660,8 +820,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) @@ -670,14 +830,14 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, err } lowerVD := d.lowerVD - return fs.vfsfs.VirtualFilesystem().ListxattrAt(ctx, d.fs.creds, &vfs.PathOperation{ + return fs.vfsfs.VirtualFilesystem().ListXattrAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: lowerVD, Start: lowerVD, }, size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) @@ -686,20 +846,20 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", err } lowerVD := d.lowerVD - return fs.vfsfs.VirtualFilesystem().GetxattrAt(ctx, d.fs.creds, &vfs.PathOperation{ + return fs.vfsfs.VirtualFilesystem().GetXattrAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: lowerVD, Start: lowerVD, }, &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { // Verity file system is read-only. return syserror.EROFS } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { // Verity file system is read-only. return syserror.EROFS } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index eedb5f484..fc5eabbca 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -22,16 +22,23 @@ package verity import ( + "fmt" + "strconv" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Name is the default filesystem name. @@ -50,8 +57,9 @@ const merkleOffsetInParentXattr = "user.merkle.offset" // whole file. For a directory, it's the size of all its children's root hashes. const merkleSizeXattr = "user.merkle.size" -// sizeOfInt32 is the size in bytes for a 32 bit integer in extended attributes. -const sizeOfInt32 = 4 +// sizeOfStringInt32 is the size for a 32 bit integer stored as string in +// extended attributes. The maximum value of a 32 bit integer is 10 digits. +const sizeOfStringInt32 = 10 // noCrashOnVerificationFailure indicates whether the sandbox should panic // whenever verification fails. If true, an error is returned instead of @@ -66,9 +74,13 @@ var noCrashOnVerificationFailure bool var verityMu sync.RWMutex // FilesystemType implements vfs.FilesystemType. +// +// +stateify savable type FilesystemType struct{} // filesystem implements vfs.FilesystemImpl. +// +// +stateify savable type filesystem struct { vfsfs vfs.Filesystem @@ -93,11 +105,13 @@ type filesystem struct { // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. - renameMu sync.RWMutex + renameMu sync.RWMutex `state:"nosave"` } // InternalFilesystemOptions may be passed as // vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem. +// +// +stateify savable type InternalFilesystemOptions struct { // RootMerkleFileName is the name of the verity root Merkle tree file. RootMerkleFileName string @@ -128,6 +142,16 @@ func (FilesystemType) Name() string { return Name } +// alertIntegrityViolation alerts a violation of integrity, which usually means +// unexpected modification to the file system is detected. In +// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. +func alertIntegrityViolation(err error, msg string) error { + if noCrashOnVerificationFailure { + return err + } + panic(msg) +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { iopts, ok := opts.InternalData.(InternalFilesystemOptions) @@ -141,6 +165,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // verity, and should not be exposed or connected. mopts := &vfs.MountOptions{ GetFilesystemOptions: iopts.LowerGetFSOptions, + InternalMount: true, } mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) if err != nil { @@ -197,15 +222,12 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } } else if err != nil { - // Failed to get dentry for the root Merkle file. This indicates - // an attack that removed/renamed the root Merkle file, or it's - // never generated. - if noCrashOnVerificationFailure { - fs.vfsfs.DecRef(ctx) - d.DecRef(ctx) - return nil, nil, err - } - panic("Failed to find root Merkle file") + // Failed to get dentry for the root Merkle file. This + // indicates an unexpected modification that removed/renamed + // the root Merkle file, or it's never generated. + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -243,6 +265,8 @@ func (fs *filesystem) Release(ctx context.Context) { } // dentry implements vfs.DentryImpl. +// +// +stateify savable type dentry struct { vfsd vfs.Dentry @@ -269,7 +293,7 @@ type dentry struct { // and dirents (if not nil) is a cache of dirents as returned by // directoryFDs representing this directory. children is protected by // dirMu. - dirMu sync.Mutex + dirMu sync.Mutex `state:"nosave"` children map[string]*dentry // lowerVD is the VirtualDentry in the underlying file system. @@ -413,6 +437,8 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { // FileDescription is a wrapper of the underlying lowerFD, with support to build // Merkle trees through the Linux fs-verity API to verify contents read from // lowerFD. +// +// +stateify savable type fileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -471,12 +497,247 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// generateMerkle generates a Merkle tree file for fd. If fd points to a file +// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The root +// hash of the generated Merkle tree and the data size is returned. +// If fd points to a regular file, the data is the content of the file. If fd +// points to a directory, the data is all root hahes of its children, written +// to the Merkle tree file. +func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) { + fdReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + merkleWriter := vfs.FileReadWriteSeeker{ + FD: fd.merkleWriter, + Ctx: ctx, + } + var rootHash []byte + var dataSize uint64 + + switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { + case linux.S_IFREG: + // For a regular file, generate a Merkle tree based on its + // content. + var err error + stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return nil, 0, err + } + dataSize = stat.Size + + rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */) + if err != nil { + return nil, 0, err + } + case linux.S_IFDIR: + // For a directory, generate a Merkle tree based on the root + // hashes of its children that has already been written to the + // Merkle tree file. + merkleStat, err := fd.merkleReader.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return nil, 0, err + } + dataSize = merkleStat.Size + + rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */) + if err != nil { + return nil, 0, err + } + default: + // TODO(b/167728857): Investigate whether and how we should + // enable other types of file. + return nil, 0, syserror.EINVAL + } + return rootHash, dataSize, nil +} + +// enableVerity enables verity features on fd by generating a Merkle tree file +// and stores its root hash in its parent directory's Merkle tree. +func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) { + if !fd.d.fs.allowRuntimeEnable { + return 0, syserror.EPERM + } + + // Lock to prevent other threads performing enable or access the file + // while it's being enabled. + verityMu.Lock() + defer verityMu.Unlock() + + // In allowRuntimeEnable mode, the underlying fd and read/write fd for + // the Merkle tree file should have all been initialized. For any file + // or directory other than the root, the parent Merkle tree file should + // have also been initialized. + if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { + return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + } + + rootHash, dataSize, err := fd.generateMerkle(ctx) + if err != nil { + return 0, err + } + + if fd.parentMerkleWriter != nil { + stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return 0, err + } + + // Write the root hash of fd to the parent directory's Merkle + // tree file, as it should be part of the parent Merkle tree + // data. parentMerkleWriter is open with O_APPEND, so it + // should write directly to the end of the file. + if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { + return 0, err + } + + // Record the offset of the root hash of fd in parent directory's + // Merkle tree file. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleOffsetInParentXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return 0, err + } + } + + // Record the size of the data being hashed for fd. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleSizeXattr, + Value: strconv.Itoa(int(dataSize)), + }); err != nil { + return 0, err + } + fd.d.rootHash = append(fd.d.rootHash, rootHash...) + return 0, nil +} + +// measureVerity returns the root hash of fd, saved in args[2]. +func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + var metadata linux.DigestMetadata + + // If allowRuntimeEnable is true, an empty fd.d.rootHash indicates that + // verity is not enabled for the file. If allowRuntimeEnable is false, + // this is an integrity violation because all files should have verity + // enabled, in which case fd.d.rootHash should be set. + if len(fd.d.rootHash) == 0 { + if fd.d.fs.allowRuntimeEnable { + return 0, syserror.ENODATA + } + return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no root hash found") + } + + // The first part of VerityDigest is the metadata. + if _, err := metadata.CopyIn(t, verityDigest); err != nil { + return 0, err + } + if metadata.DigestSize < uint16(len(fd.d.rootHash)) { + return 0, syserror.EOVERFLOW + } + + // Populate the output digest size, since DigestSize is both input and + // output. + metadata.DigestSize = uint16(len(fd.d.rootHash)) + + // First copy the metadata. + if _, err := metadata.CopyOut(t, verityDigest); err != nil { + return 0, err + } + + // Now copy the root hash bytes to the memory after metadata. + _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.rootHash) + return 0, err +} + +func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) { + f := int32(0) + + // All enabled files should store a root hash. This flag is not settable + // via FS_IOC_SETFLAGS. + if len(fd.d.rootHash) != 0 { + f |= linux.FS_VERITY_FL + } + + t := kernel.TaskFromContext(ctx) + _, err := primitive.CopyInt32Out(t, flags, f) + return 0, err +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + switch cmd := args[1].Uint(); cmd { + case linux.FS_IOC_ENABLE_VERITY: + return fd.enableVerity(ctx, uio) + case linux.FS_IOC_MEASURE_VERITY: + return fd.measureVerity(ctx, uio, args[2].Pointer()) + case linux.FS_IOC_GETFLAGS: + return fd.verityFlags(ctx, uio, args[2].Pointer()) + default: + // TODO(b/169682228): Investigate which ioctl commands should + // be allowed. + return 0, syserror.ENOSYS + } +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // No need to verify if the file is not enabled yet in + // allowRuntimeEnable mode. + if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 { + return fd.lowerFD.PRead(ctx, dst, offset, opts) + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return 0, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + dataReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */) + if err != nil { + return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err)) + } + return n, err +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { - return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) + return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { - return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) + return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) } diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go new file mode 100644 index 000000000..9b5641092 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -0,0 +1,349 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "fmt" + "io" + "math/rand" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// rootMerkleFilename is the name of the root Merkle tree file. +const rootMerkleFilename = "root.verity" + +// maxDataSize is the maximum data size written to the file for test. +const maxDataSize = 100000 + +// newVerityRoot creates a new verity mount, and returns the root. The +// underlying file system is tmpfs. If the error is not nil, then cleanup +// should be called when the root is no longer needed. +func newVerityRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) { + rand.Seed(time.Now().UnixNano()) + vfsObj := &vfs.VirtualFilesystem{} + if err := vfsObj.Init(ctx); err != nil { + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) + } + + vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: InternalFilesystemOptions{ + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + AllowRuntimeEnable: true, + NoCrashOnVerificationFailure: true, + }, + }, + }) + if err != nil { + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create verity root mount: %v", err) + } + root := mntns.Root() + return vfsObj, root, func() { + root.DecRef(ctx) + mntns.DecRef(ctx) + }, nil +} + +// newFileFD creates a new file in the verity mount, and returns the FD. The FD +// points to a file that has random data generated. +func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { + creds := auth.CredentialsFromContext(ctx) + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + + // Create the file in the underlying file system. + lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, + Mode: linux.ModeRegular | mode, + }) + if err != nil { + return nil, 0, err + } + + // Generate random data to be written to the file. + dataSize := rand.Intn(maxDataSize) + 1 + data := make([]byte, dataSize) + rand.Read(data) + + // Write directly to the underlying FD, since verity FD is read-only. + n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + return nil, 0, err + } + + if n != int64(len(data)) { + return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data)) + } + + lowerFD.DecRef(ctx) + + // Now open the verity file descriptor. + fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular | mode, + }) + return fd, dataSize, err +} + +// corruptRandomBit randomly flips a bit in the file represented by fd. +func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { + // Flip a random bit in the underlying file. + randomPos := int64(rand.Intn(size)) + byteToModify := make([]byte, 1) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { + return fmt.Errorf("lowerFD.PRead failed: %v", err) + } + byteToModify[0] ^= 1 + if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil { + return fmt.Errorf("lowerFD.PWrite failed: %v", err) + } + return nil +} + +// TestOpen ensures that when a file is created, the corresponding Merkle tree +// file and the root Merkle tree file exist. +func TestOpen(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("Failed to create new verity root: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("Failed to create new file fd: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("Failed to open Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("Failed to open root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } +} + +// TestUntouchedFileSucceeds ensures that read from an untouched verity file +// succeeds after enabling verity for it. +func TestReadUntouchedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("Failed to create new verity root: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("Failed to create new file fd: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl failed: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead failed: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } +} + +// TestReopenUntouchedFileSucceeds ensures that reopen an untouched verity file +// succeeds after enabling verity for it. +func TestReopenUntouchedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("Failed to create new verity root: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("Failed to create new file fd: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl failed: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } +} + +// TestModifiedFileFails ensures that read from a modified verity file fails. +func TestModifiedFileFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("Failed to create new verity root: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("Failed to create new file fd: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl failed: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("Open lowerFD failed: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("%v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded with modified file") + } +} + +// TestModifiedMerkleFails ensures that read from a verity file fails if the +// corresponding Merkle tree file is modified. +func TestModifiedMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("Failed to create new verity root: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("Failed to create new file fd: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl failed: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("Open lowerMerkleFD failed: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("Failed to get lowerMerkleFD stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("%v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } +} diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 07bf39fed..5bba9de0b 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -15,6 +15,7 @@ go_library( ], deps = [ "//pkg/context", + "//pkg/tcpip", "//pkg/tcpip/stack", ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index c0b4831d1..fbe6d6aa6 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,7 +15,10 @@ // Package inet defines semantics for IP stacks. package inet -import "gvisor.dev/gvisor/pkg/tcpip/stack" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) // Stack represents a TCP/IP stack. type Stack interface { @@ -80,6 +83,12 @@ type Stack interface { // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. RestoreCleanupEndpoints([]stack.TransportEndpoint) + + // Forwarding returns if packet forwarding between NICs is enabled. + Forwarding(protocol tcpip.NetworkProtocolNumber) bool + + // SetForwarding enables or disables packet forwarding between NICs. + SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 9771f01fc..1779cc6f3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,7 +14,10 @@ package inet -import "gvisor.dev/gvisor/pkg/tcpip/stack" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) // TestStack is a dummy implementation of Stack for tests. type TestStack struct { @@ -26,6 +29,7 @@ type TestStack struct { TCPSendBufSize TCPBufferSize TCPSACKFlag bool Recovery TCPLossRecovery + IPForwarding bool } // NewTestStack returns a TestStack with no network interfaces. The value of @@ -128,3 +132,14 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + return s.IPForwarding +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + s.IPForwarding = enable + return nil +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index d436daab4..083071b5e 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -69,8 +69,8 @@ go_template_instance( prefix = "socket", template = "//pkg/ilist:generic_list", types = { - "Element": "*SocketEntry", - "Linker": "*SocketEntry", + "Element": "*SocketRecordVFS1", + "Linker": "*SocketRecordVFS1", }, ) @@ -197,6 +197,7 @@ go_library( "gvisor.dev/gvisor/pkg/sentry/device", "gvisor.dev/gvisor/pkg/tcpip", ], + marshal = True, visibility = ["//:sandbox"], deps = [ ":uncaught_signal_go_proto", @@ -212,6 +213,8 @@ go_library( "//pkg/eventchannel", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", "//pkg/refs_vfs2", @@ -261,7 +264,6 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 2bc49483a..869e49ebc 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -57,6 +57,7 @@ go_library( "id_map_set.go", "user_namespace.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go index 0a58ba17c..4c32ee703 100644 --- a/pkg/sentry/kernel/auth/id.go +++ b/pkg/sentry/kernel/auth/id.go @@ -19,9 +19,13 @@ import ( ) // UID is a user ID in an unspecified user namespace. +// +// +marshal type UID uint32 // GID is a group ID in an unspecified user namespace. +// +// +marshal slice:GIDSlice type GID uint32 // In the root user namespace, user/group IDs have a 1-to-1 relationship with diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 5773244ac..0ec7344cd 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -111,8 +111,11 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() f.init() // Initialize table. + f.used = 0 for fd, d := range m { - f.setAll(fd, d.file, d.fileVFS2, d.flags) + if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { + panic("VFS1 or VFS2 files set") + } // Note that we do _not_ need to acquire a extra table reference here. The // table reference will already be accounted for in the file, so we drop the @@ -127,7 +130,7 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { } // drop drops the table reference. -func (f *FDTable) drop(file *fs.File) { +func (f *FDTable) drop(ctx context.Context, file *fs.File) { // Release locks. file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF}) @@ -145,14 +148,13 @@ func (f *FDTable) drop(file *fs.File) { d.InotifyEvent(ev, 0) // Drop the table reference. - file.DecRef(context.Background()) + file.DecRef(ctx) } // dropVFS2 drops the table reference. -func (f *FDTable) dropVFS2(file *vfs.FileDescription) { +func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the // entire file. - ctx := context.Background() err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET) if err != nil && err != syserror.ENOLCK { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) @@ -187,12 +189,6 @@ func (f *FDTable) DecRef(ctx context.Context) { }) } -// Size returns the number of file descriptor slots currently allocated. -func (f *FDTable) Size() int { - size := atomic.LoadInt32(&f.used) - return int(size) -} - // forEach iterates over all non-nil files in sorted order. // // It is the caller's responsibility to acquire an appropriate lock. @@ -279,7 +275,6 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags } f.mu.Lock() - defer f.mu.Unlock() // From f.next to find available fd. if fd < f.next { @@ -289,15 +284,25 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { - f.set(i, files[len(fds)], flags) // Set the descriptor. - fds = append(fds, i) // Record the file descriptor. + // Set the descriptor. + f.set(ctx, i, files[len(fds)], flags) + fds = append(fds, i) // Record the file descriptor. } } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { - f.set(i, nil, FDFlags{}) // Zap entry. + f.set(ctx, i, nil, FDFlags{}) + } + f.mu.Unlock() + + // Drop the reference taken by the call to f.set() that + // originally installed the file. Don't call f.drop() + // (generating inotify events, etc.) since the file should + // appear to have never been inserted into f. + for _, file := range files[:len(fds)] { + file.DecRef(ctx) } return nil, syscall.EMFILE } @@ -307,6 +312,7 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.next = fds[len(fds)-1] + 1 } + f.mu.Unlock() return fds, nil } @@ -334,7 +340,6 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes } f.mu.Lock() - defer f.mu.Unlock() // From f.next to find available fd. if fd < f.next { @@ -344,15 +349,25 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.getVFS2(i); d == nil { - f.setVFS2(i, files[len(fds)], flags) // Set the descriptor. - fds = append(fds, i) // Record the file descriptor. + // Set the descriptor. + f.setVFS2(ctx, i, files[len(fds)], flags) + fds = append(fds, i) // Record the file descriptor. } } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { - f.setVFS2(i, nil, FDFlags{}) // Zap entry. + f.setVFS2(ctx, i, nil, FDFlags{}) + } + f.mu.Unlock() + + // Drop the reference taken by the call to f.setVFS2() that + // originally installed the file. Don't call f.dropVFS2() + // (generating inotify events, etc.) since the file should + // appear to have never been inserted into f. + for _, file := range files[:len(fds)] { + file.DecRef(ctx) } return nil, syscall.EMFILE } @@ -362,6 +377,7 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes f.next = fds[len(fds)-1] + 1 } + f.mu.Unlock() return fds, nil } @@ -397,7 +413,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc } for fd < end { if d, _, _ := f.getVFS2(fd); d == nil { - f.setVFS2(fd, file, flags) + f.setVFS2(ctx, fd, file, flags) if fd == f.next { // Update next search start position. f.next = fd + 1 @@ -413,40 +429,55 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc // reference for that FD, the ref count for that existing reference is // decremented. func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error { - return f.newFDAt(ctx, fd, file, nil, flags) + df, _, err := f.newFDAt(ctx, fd, file, nil, flags) + if err != nil { + return err + } + if df != nil { + f.drop(ctx, df) + } + return nil } // NewFDAtVFS2 sets the file reference for the given FD. If there is an active // reference for that FD, the ref count for that existing reference is // decremented. func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error { - return f.newFDAt(ctx, fd, nil, file, flags) + _, dfVFS2, err := f.newFDAt(ctx, fd, nil, file, flags) + if err != nil { + return err + } + if dfVFS2 != nil { + f.dropVFS2(ctx, dfVFS2) + } + return nil } -func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error { +func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription, error) { if fd < 0 { // Don't accept negative FDs. - return syscall.EBADF + return nil, nil, syscall.EBADF } // Check the limit for the provided file. if limitSet := limits.FromContext(ctx); limitSet != nil { if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur { - return syscall.EMFILE + return nil, nil, syscall.EMFILE } } // Install the entry. f.mu.Lock() defer f.mu.Unlock() - f.setAll(fd, file, fileVFS2, flags) - return nil + + df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags) + return df, dfVFS2, nil } // SetFlags sets the flags for the given file descriptor. // // True is returned iff flags were changed. -func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { +func (f *FDTable) SetFlags(ctx context.Context, fd int32, flags FDFlags) error { if fd < 0 { // Don't accept negative FDs. return syscall.EBADF @@ -462,14 +493,14 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { } // Update the flags. - f.set(fd, file, flags) + f.set(ctx, fd, file, flags) return nil } // SetFlagsVFS2 sets the flags for the given file descriptor. // // True is returned iff flags were changed. -func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { +func (f *FDTable) SetFlagsVFS2(ctx context.Context, fd int32, flags FDFlags) error { if fd < 0 { // Don't accept negative FDs. return syscall.EBADF @@ -485,7 +516,7 @@ func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { } // Update the flags. - f.setVFS2(fd, file, flags) + f.setVFS2(ctx, fd, file, flags) return nil } @@ -551,30 +582,6 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 { return fds } -// GetRefs returns a stable slice of references to all files and bumps the -// reference count on each. The caller must use DecRef on each reference when -// they're done using the slice. -func (f *FDTable) GetRefs(ctx context.Context) []*fs.File { - files := make([]*fs.File, 0, f.Size()) - f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - file.IncRef() // Acquire a reference for caller. - files = append(files, file) - }) - return files -} - -// GetRefsVFS2 returns a stable slice of references to all files and bumps the -// reference count on each. The caller must use DecRef on each reference when -// they're done using the slice. -func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription { - files := make([]*vfs.FileDescription, 0, f.Size()) - f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) { - file.IncRef() // Acquire a reference for caller. - files = append(files, file) - }) - return files -} - // Fork returns an independent FDTable. func (f *FDTable) Fork(ctx context.Context) *FDTable { clone := f.k.NewFDTable() @@ -582,11 +589,8 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable { f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { // The set function here will acquire an appropriate table // reference for the clone. We don't need anything else. - switch { - case file != nil: - clone.set(fd, file, flags) - case fileVFS2 != nil: - clone.setVFS2(fd, fileVFS2, flags) + if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil { + panic("VFS1 or VFS2 files set") } }) return clone @@ -595,13 +599,12 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable { // Remove removes an FD from and returns a non-file iff successful. // // N.B. Callers are required to use DecRef when they are done. -func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { +func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDescription) { if fd < 0 { return nil, nil } f.mu.Lock() - defer f.mu.Unlock() // Update current available position. if fd < f.next { @@ -617,24 +620,51 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { case orig2 != nil: orig2.IncRef() } + if orig != nil || orig2 != nil { - f.setAll(fd, nil, nil, FDFlags{}) // Zap entry. + orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry. + } + f.mu.Unlock() + + if orig != nil { + f.drop(ctx, orig) } + if orig2 != nil { + f.dropVFS2(ctx, orig2) + } + return orig, orig2 } // RemoveIf removes all FDs where cond is true. func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) { - f.mu.Lock() - defer f.mu.Unlock() + // TODO(gvisor.dev/issue/1624): Remove fs.File slice. + var files []*fs.File + var filesVFS2 []*vfs.FileDescription + f.mu.Lock() f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { if cond(file, fileVFS2, flags) { - f.set(fd, nil, FDFlags{}) // Clear from table. + df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table. + if df != nil { + files = append(files, df) + } + if dfVFS2 != nil { + filesVFS2 = append(filesVFS2, dfVFS2) + } // Update current available position. if fd < f.next { f.next = fd } } }) + f.mu.Unlock() + + for _, file := range files { + f.drop(ctx, file) + } + + for _, file := range filesVFS2 { + f.dropVFS2(ctx, file) + } } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index e3f30ba2a..bf5460083 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -72,7 +72,7 @@ func TestFDTableMany(t *testing.T) { } i := int32(2) - fdTable.Remove(i) + fdTable.Remove(ctx, i) if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) } @@ -93,7 +93,7 @@ func TestFDTableOverLimit(t *testing.T) { t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) } else { for _, fd := range fds { - fdTable.Remove(fd) + fdTable.Remove(ctx, fd) } } @@ -150,13 +150,13 @@ func TestFDTable(t *testing.T) { t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref) } - ref, _ := fdTable.Remove(1) + ref, _ := fdTable.Remove(ctx, 1) if ref == nil { t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success") } ref.DecRef(ctx) - if ref, _ := fdTable.Remove(1); ref != nil { + if ref, _ := fdTable.Remove(ctx, 1); ref != nil { t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") } }) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 6b8feb107..da79e6627 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -78,33 +79,37 @@ func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, boo return d.file, d.fileVFS2, d.flags, true } -// set sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// CurrentMaxFDs returns the number of file descriptors that may be stored in f +// without reallocation. +func (f *FDTable) CurrentMaxFDs() int { + slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + return len(slice) +} + +// set sets an entry for VFS1, refer to setAll(). // // Precondition: mu must be held. -func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) { - f.setAll(fd, file, nil, flags) +func (f *FDTable) set(ctx context.Context, fd int32, file *fs.File, flags FDFlags) *fs.File { + dropFile, _ := f.setAll(ctx, fd, file, nil, flags) + return dropFile } -// setVFS2 sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// setVFS2 sets an entry for VFS2, refer to setAll(). // // Precondition: mu must be held. -func (f *FDTable) setVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) { - f.setAll(fd, nil, file, flags) +func (f *FDTable) setVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) *vfs.FileDescription { + _, dropFile := f.setAll(ctx, fd, nil, file, flags) + return dropFile } -// setAll sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// setAll sets the file description referred to by fd to file/fileVFS2. If +// file/fileVFS2 are non-nil, it takes a reference on them. If setAll replaces +// an existing file description, it returns it with the FDTable's reference +// transferred to the caller, which must call f.drop/dropVFS2() on the returned +// file after unlocking f.mu. // // Precondition: mu must be held. -func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { +func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription) { if file != nil && fileVFS2 != nil { panic("VFS1 and VFS2 files set") } @@ -147,25 +152,25 @@ func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, } } - // Drop the table reference. + // Adjust used. + switch { + case orig == nil && desc != nil: + atomic.AddInt32(&f.used, 1) + case orig != nil && desc == nil: + atomic.AddInt32(&f.used, -1) + } + if orig != nil { switch { case orig.file != nil: if desc == nil || desc.file != orig.file { - f.drop(orig.file) + return orig.file, nil } case orig.fileVFS2 != nil: if desc == nil || desc.fileVFS2 != orig.fileVFS2 { - f.dropVFS2(orig.fileVFS2) + return nil, orig.fileVFS2 } } } - - // Adjust used. - switch { - case orig == nil && desc != nil: - atomic.AddInt32(&f.used, 1) - case orig != nil && desc == nil: - atomic.AddInt32(&f.used, -1) - } + return nil, nil } diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go index aad63aa99..d3e76ca7b 100644 --- a/pkg/sentry/kernel/kcov.go +++ b/pkg/sentry/kernel/kcov.go @@ -89,6 +89,10 @@ func (kcov *Kcov) TaskWork(t *Task) { kcov.mu.Lock() defer kcov.mu.Unlock() + if kcov.mode != linux.KCOV_TRACE_PC { + return + } + rw := &kcovReadWriter{ mf: kcov.mfp.MemoryFile(), fr: kcov.mappable.FileRange(), diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 402aa1718..d6c21adb7 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -220,13 +220,18 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // sockets is the list of all network sockets the system. Protected by - // extMu. + // sockets is the list of all network sockets in the system. + // Protected by extMu. + // TODO(gvisor.dev/issue/1624): Only used by VFS1. sockets socketList - // nextSocketEntry is the next entry number to use in sockets. Protected + // socketsVFS2 records all network sockets in the system. Protected by + // extMu. + socketsVFS2 map[*vfs.FileDescription]*SocketRecord + + // nextSocketRecord is the next entry number to use in sockets. Protected // by extMu. - nextSocketEntry uint64 + nextSocketRecord uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("failed to create sockfs mount: %v", err) } k.socketMount = socketMount + + k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) } return nil @@ -507,6 +514,10 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { + if VFS2Enabled { + return nil // Not relevant. + } + // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -533,11 +544,6 @@ func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { // // Precondition: Must be called with the kernel paused. func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.FileDescription) error) (err error) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return nil - } - ts.mu.RLock() defer ts.mu.RUnlock() for t := range ts.Root.tids { @@ -556,6 +562,10 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { // TODO(gvisor.dev/issue/1663): Add save support for VFS2. + if VFS2Enabled { + return nil + } + return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil @@ -888,17 +898,18 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, opener fsbridge.Lookup fsContext *FSContext mntns *fs.MountNamespace + mntnsVFS2 *vfs.MountNamespace ) if VFS2Enabled { - mntnsVFS2 := args.MountNamespaceVFS2 + mntnsVFS2 = args.MountNamespaceVFS2 if mntnsVFS2 == nil { // MountNamespaceVFS2 adds a reference to the namespace, which is // transferred to the new process. mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2() } // Get the root directory from the MountNamespace. - root := args.MountNamespaceVFS2.Root() + root := mntnsVFS2.Root() // The call to newFSContext below will take a reference on root, so we // don't need to hold this one. defer root.DecRef(ctx) @@ -1008,7 +1019,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, UTSNamespace: args.UTSNamespace, IPCNamespace: args.IPCNamespace, AbstractSocketNamespace: args.AbstractSocketNamespace, - MountNamespaceVFS2: args.MountNamespaceVFS2, + MountNamespaceVFS2: mntnsVFS2, ContainerID: args.ContainerID, } t, err := k.tasks.NewTask(config) @@ -1508,20 +1519,27 @@ func (k *Kernel) SupervisorContext() context.Context { } } -// SocketEntry represents a socket recorded in Kernel.sockets. It implements +// SocketRecord represents a socket recorded in Kernel.socketsVFS2. +// +// +stateify savable +type SocketRecord struct { + k *Kernel + Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1. + SockVFS2 *vfs.FileDescription // Only used by VFS2. + ID uint64 // Socket table entry number. +} + +// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type SocketEntry struct { +type SocketRecordVFS1 struct { socketEntry - k *Kernel - Sock *refs.WeakRef - SockVFS2 *vfs.FileDescription - ID uint64 // Socket table entry number. + SocketRecord } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *SocketEntry) WeakRefGone(context.Context) { +func (s *SocketRecordVFS1) WeakRefGone(context.Context) { s.k.extMu.Lock() s.k.sockets.Remove(s) s.k.extMu.Unlock() @@ -1532,9 +1550,14 @@ func (s *SocketEntry) WeakRefGone(context.Context) { // Precondition: Caller must hold a reference to sock. func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{k: k, ID: id} + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecordVFS1{ + SocketRecord: SocketRecord{ + k: k, + ID: id, + }, + } s.Sock = refs.NewWeakRef(sock, s) k.sockets.PushBack(s) k.extMu.Unlock() @@ -1546,29 +1569,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) { // Precondition: Caller must hold a reference to sock. // // Note that the socket table will not hold a reference on the -// vfs.FileDescription, because we do not support weak refs on VFS2 files. +// vfs.FileDescription. func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{ + if _, ok := k.socketsVFS2[sock]; ok { + panic(fmt.Sprintf("Socket %p added twice", sock)) + } + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecord{ k: k, ID: id, SockVFS2: sock, } - k.sockets.PushBack(s) + k.socketsVFS2[sock] = s + k.extMu.Unlock() +} + +// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table. +func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) { + k.extMu.Lock() + delete(k.socketsVFS2, sock) k.extMu.Unlock() } // ListSockets returns a snapshot of all sockets. // -// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef() +// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef() // to get a reference on a socket in the table. -func (k *Kernel) ListSockets() []*SocketEntry { +func (k *Kernel) ListSockets() []*SocketRecord { k.extMu.Lock() - var socks []*SocketEntry - for s := k.sockets.Front(); s != nil; s = s.Next() { - socks = append(socks, s) + var socks []*SocketRecord + if VFS2Enabled { + for _, s := range k.socketsVFS2 { + socks = append(socks, s) + } + } else { + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, &s.SocketRecord) + } } k.extMu.Unlock() return socks diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 449643118..99134e634 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/amutex", "//pkg/buffer", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index c410c96aa..67beb0ad6 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -17,6 +17,7 @@ package pipe import ( "fmt" + "io" "sync/atomic" "syscall" @@ -215,7 +216,7 @@ func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { if p.view.Size() == 0 { if !p.HasWriters() { // There are no writers, return EOF. - return 0, nil + return 0, io.EOF } return 0, syserror.ErrWouldBlock } diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 6d58b682f..f665920cb 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -145,9 +146,14 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume v = math.MaxInt32 // Silently truncate. } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + iocc := primitive.IOCopyContext{ + IO: io, + Ctx: ctx, + Opts: usermem.IOOpts{ + AddressSpaceActive: true, + }, + } + _, err := primitive.CopyInt32Out(&iocc, args[2].Pointer(), int32(v)) return 0, err default: return 0, syscall.ENOTTY diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index f223d59e1..f61039f5b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -67,6 +67,11 @@ func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlag return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (*VFSPipe) Allocate(context.Context, uint64, uint64, uint64) error { + return syserror.ESPIPE +} + // Open opens the pipe represented by vp. func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 50df179c3..1145faf13 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" @@ -999,18 +1000,15 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{ - IgnorePermissions: true, - }); err != nil { + if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } - _, err := t.CopyOut(data, word) + _, err := word.CopyOut(t, data) return err case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: - _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{ - IgnorePermissions: true, - }) + word := t.Arch().Native(uintptr(data)) + _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: @@ -1078,12 +1076,12 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if target.ptraceSiginfo == nil { return syserror.EINVAL } - _, err := t.CopyOut(data, target.ptraceSiginfo) + _, err := target.ptraceSiginfo.CopyOut(t, data) return err case linux.PTRACE_SETSIGINFO: var info arch.SignalInfo - if _, err := t.CopyIn(data, &info); err != nil { + if _, err := info.CopyIn(t, data); err != nil { return err } t.tg.pidns.owner.mu.RLock() @@ -1098,7 +1096,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if addr != linux.SignalSetSize { return syserror.EINVAL } - _, err := t.CopyOut(data, target.SignalMask()) + mask := target.SignalMask() + _, err := mask.CopyOut(t, data) return err case linux.PTRACE_SETSIGMASK: @@ -1106,7 +1105,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { return syserror.EINVAL } var mask linux.SignalSet - if _, err := t.CopyIn(data, &mask); err != nil { + if _, err := mask.CopyIn(t, data); err != nil { return err } // The target's task goroutine is stopped, so this is safe: @@ -1121,7 +1120,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_GETEVENTMSG: t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() - _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg) + _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg) return err // PEEKSIGINFO is unimplemented but seems to have no users anywhere. diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index cef1276ec..609ad3941 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -30,7 +30,7 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro if err != nil { return err } - _, err = t.CopyOut(data, n) + _, err = n.CopyOut(t, data) return err case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 413111faf..332bdb8e8 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -348,6 +348,16 @@ func (s *SyscallTable) LookupName(sysno uintptr) string { return fmt.Sprintf("sys_%d", sysno) // Unlikely. } +// LookupNo looks up a syscall number by name. +func (s *SyscallTable) LookupNo(name string) (uintptr, error) { + for i, syscall := range s.Table { + if syscall.Name == name { + return uintptr(i), nil + } + } + return 0, fmt.Errorf("syscall %q not found", name) +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 9d7a9128f..fce1064a7 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -341,12 +341,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { nt.SetClearTID(opts.ChildTID) } if opts.ChildSetTID { - // Can't use Task.CopyOut, which assumes AddressSpaceActive. - usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{}) + ctid := nt.ThreadID() + ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { - t.CopyOut(opts.ParentTID, ntid) + ntid.CopyOut(t, opts.ParentTID) } kind := ptraceCloneKindClone diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 9fa528384..d1136461a 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -126,7 +126,11 @@ func (t *Task) SyscallTable() *SyscallTable { // Preconditions: The caller must be running on the task goroutine, or t.mu // must be locked. func (t *Task) Stack() *arch.Stack { - return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} + return &arch.Stack{ + Arch: t.Arch(), + IO: t.MemoryManager(), + Bottom: usermem.Addr(t.Arch().Stack()), + } } // LoadTaskImage loads a specified file into a new TaskContext. diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index b76f7f503..b400a8b41 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -248,7 +248,8 @@ func (*runExitMain) execute(t *Task) taskRunState { signaled := t.tg.exiting && t.tg.exitStatus.Signaled() t.tg.signalHandlers.mu.Unlock() if !signaled { - if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil { + zero := ThreadID(0) + if _, err := zero.CopyOut(t, t.cleartid); err == nil { t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1) } // If the CopyOut fails, there's nothing we can do. diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index 4b535c949..c80391475 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" ) @@ -87,7 +88,7 @@ func (t *Task) exitRobustList() { return } - next := rl.List + next := primitive.Uint64(rl.List) done := 0 var pendingLockAddr usermem.Addr if rl.ListOpPending != 0 { @@ -99,12 +100,12 @@ func (t *Task) exitRobustList() { // We traverse to the next element of the list before we // actually wake anything. This prevents the race where waking // this futex causes a modification of the list. - thisLockAddr := usermem.Addr(next + rl.FutexOffset) + thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset) // Try to decode the next element in the list before waking the // current futex. But don't check the error until after we've // woken the current futex. Linux does it in this order too - _, nextErr := t.CopyIn(usermem.Addr(next), &next) + _, nextErr := next.CopyIn(t, usermem.Addr(next)) // Wakeup the current futex if it's not pending. if thisLockAddr != pendingLockAddr { diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index aa3a573c0..8dc3fec90 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -141,7 +141,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error { region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) - _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) + _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found) if err == nil && bytes.Equal(expected, found) { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index feaa38596..ebdb83061 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -259,7 +259,11 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Set up the signal handler. If we have a saved signal mask, the signal // handler should run with the current mask, but sigreturn should restore // the saved one. - st := &arch.Stack{t.Arch(), mm, sp} + st := &arch.Stack{ + Arch: t.Arch(), + IO: mm, + Bottom: sp, + } mask := t.signalMask if t.haveSavedSignalMask { mask = t.savedSignalMask diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 2dbf86547..0141459e7 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -287,7 +288,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { // Grab the caller up front, to make sure there's a sensible stack. caller := t.Arch().Native(uintptr(0)) - if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil { + if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil { t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) @@ -323,7 +324,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { type runVsyscallAfterPtraceEventSeccomp struct { addr usermem.Addr sysno uintptr - caller interface{} + caller marshal.Marshallable } func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { @@ -346,7 +347,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller) } -func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState { +func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller marshal.Marshallable) taskRunState { rval, ctrl, err := t.executeSyscall(sysno, args) if ctrl != nil { t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl) diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index 0cb86e390..ce134bf54 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -43,17 +44,6 @@ func (t *Task) Deactivate() { } } -// CopyIn copies a fixed-size value or slice of fixed-size values in from the -// task's memory. The copy will fail with syscall.EFAULT if it traverses user -// memory that is unmapped or not readable by the user. -// -// This Task's AddressSpace must be active. -func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) { - return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{ - AddressSpaceActive: true, - }) -} - // CopyInBytes is a fast version of CopyIn if the caller can serialize the // data without reflection and pass in a byte slice. // @@ -64,17 +54,6 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { }) } -// CopyOut copies a fixed-size value or slice of fixed-size values out to the -// task's memory. The copy will fail with syscall.EFAULT if it traverses user -// memory that is unmapped or not writeable by the user. -// -// This Task's AddressSpace must be active. -func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) { - return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{ - AddressSpaceActive: true, - }) -} - // CopyOutBytes is a fast version of CopyOut if the caller can serialize the // data without reflection and pass in a byte slice. // @@ -114,7 +93,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ var v []string for { argAddr := t.Arch().Native(0) - if _, err := t.CopyIn(addr, argAddr); err != nil { + if _, err := argAddr.CopyIn(t, addr); err != nil { return v, err } if t.Arch().Value(argAddr) == 0 { @@ -302,29 +281,29 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp }, nil } -// CopyContextWithOpts wraps a task to allow copying memory to and from the -// task memory with user specified usermem.IOOpts. -type CopyContextWithOpts struct { +// copyContext implements marshal.CopyContext. It wraps a task to allow copying +// memory to and from the task memory with custom usermem.IOOpts. +type copyContext struct { *Task opts usermem.IOOpts } -// AsCopyContextWithOpts wraps the task and returns it as CopyContextWithOpts. -func (t *Task) AsCopyContextWithOpts(opts usermem.IOOpts) *CopyContextWithOpts { - return &CopyContextWithOpts{t, opts} +// AsCopyContext wraps the task and returns it as CopyContext. +func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { + return ©Context{t, opts} } // CopyInString copies a string in from the task's memory. -func (t *CopyContextWithOpts) CopyInString(addr usermem.Addr, maxLen int) (string, error) { +func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) } // CopyInBytes copies task memory into dst from an IO context. -func (t *CopyContextWithOpts) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { +func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { return t.MemoryManager().CopyIn(t, addr, dst, t.opts) } // CopyOutBytes copies src into task memoryfrom an IO context. -func (t *CopyContextWithOpts) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { +func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { return t.MemoryManager().CopyOut(t, addr, src, t.opts) } diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 872e1a82d..5ae5906e8 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -36,6 +36,8 @@ import ( const TasksLimit = (1 << 16) // ThreadID is a generic thread identifier. +// +// +marshal type ThreadID int32 // String returns a decimal representation of the ThreadID. diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 15c88aa7c..c69b62db9 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -122,7 +122,7 @@ func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch if err != nil { return nil, err } - return &arch.Stack{a, m, ar.End}, nil + return &arch.Stack{Arch: a, IO: m, Bottom: ar.End}, nil } const ( @@ -247,20 +247,20 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(args.Filename) - if err != nil { + if _, err := stack.PushNullTerminatedByteSlice([]byte(args.Filename)); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } + execfn := stack.Bottom // Push 16 random bytes on the stack which AT_RANDOM will point to. var b [16]byte if _, err := rand.Read(b[:]); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux()) } - random, err := stack.Push(b) - if err != nil { + if _, err = stack.PushNullTerminatedByteSlice(b[:]); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux()) } + random := stack.Bottom c := auth.CredentialsFromContext(ctx) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 3e85964e4..8c9f11cce 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -242,7 +242,7 @@ type MemoryManager struct { // +stateify savable type vma struct { // mappable is the virtual memory object mapped by this vma. If mappable is - // nil, the vma represents a private anonymous mapping. + // nil, the vma represents an anonymous mapping. mappable memmap.Mappable // off is the offset into mappable at which this vma begins. If mappable is diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index fdc308542..acac3d357 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -51,7 +51,8 @@ func TestUsageASUpdates(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * usermem.PageSize, + Length: 2 * usermem.PageSize, + Private: true, }) if err != nil { t.Fatalf("MMap got err %v want nil", err) diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index f4c93baeb..2dbe5b751 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -136,9 +136,12 @@ func (m *SpecialMappable) Length() uint64 { // NewSharedAnonMappable returns a SpecialMappable that implements the // semantics of mmap(MAP_SHARED|MAP_ANONYMOUS) and mappings of /dev/zero. // -// TODO(jamieliu): The use of SpecialMappable is a lazy code reuse hack. Linux -// uses an ephemeral file created by mm/shmem.c:shmem_zero_setup(); we should -// do the same to get non-zero device and inode IDs. +// TODO(gvisor.dev/issue/1624): Linux uses an ephemeral file created by +// mm/shmem.c:shmem_zero_setup(), and VFS2 does something analogous. VFS1 uses +// a SpecialMappable instead, incorrectly getting device and inode IDs of zero +// and causing memory for shared anonymous mappings to be allocated up-front +// instead of on first touch; this is to avoid exacerbating the fs.MountSource +// leak (b/143656263). Delete this function along with VFS1. func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) { if length == 0 { return nil, syserror.EINVAL diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index 4c9a575e7..a2555ba1a 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -93,18 +92,6 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme } } else { opts.Offset = 0 - if !opts.Private { - if opts.MappingIdentity != nil { - return 0, syserror.EINVAL - } - m, err := NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) - if err != nil { - return 0, err - } - defer m.DecRef(ctx) - opts.MappingIdentity = m - opts.Mappable = m - } } if opts.Addr.RoundDown() != opts.Addr { diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 3970dd81d..323837fb1 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -9,12 +9,12 @@ go_library( "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", - "bluepill_amd64.s", "bluepill_amd64_unsafe.go", "bluepill_arm64.go", "bluepill_arm64.s", "bluepill_arm64_unsafe.go", "bluepill_fault.go", + "bluepill_impl_amd64.s", "bluepill_unsafe.go", "context.go", "filters_amd64.go", @@ -81,3 +81,11 @@ go_test( "//pkg/usermem", ], ) + +genrule( + name = "bluepill_impl_amd64", + srcs = ["bluepill_amd64.s"], + outs = ["bluepill_impl_amd64.s"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index 2bc34a435..025ea93b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -19,11 +19,6 @@ // This is guaranteed to be zero. #define VCPU_CPU 0x0 -// CPU_SELF is the self reference in ring0's percpu. -// -// This is guaranteed to be zero. -#define CPU_SELF 0x0 - // Context offsets. // // Only limited use of the context is done in the assembly stub below, most is @@ -44,7 +39,7 @@ begin: LEAQ VCPU_CPU(AX), BX BYTE CLI; check_vcpu: - MOVQ CPU_SELF(GS), CX + MOVQ ENTRY_CPU_SELF(GS), CX CMPQ BX, CX JE right_vCPU wrong_vcpu: diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ae813e24e..d46946402 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -156,15 +156,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. pageTables := pagetables.New(newAllocator()) - applyPhysicalRegions(func(pr physicalRegion) bool { - // Map the kernel in the upper half. - pageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. - }) + k.machine.mapUpperHalf(pageTables) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 372a4cbd7..75da253c5 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -155,7 +155,7 @@ func (m *machine) newVCPU() *vCPU { fd: int(fd), machine: m, } - c.CPU.Init(&m.kernel, c) + c.CPU.Init(&m.kernel, c.id, c) m.vCPUsByID[c.id] = c // Ensure the signal mask is correct. @@ -183,9 +183,6 @@ func newMachine(vm int) (*machine, error) { // Create the machine. m := &machine{fd: vm} m.available.L = &m.mu - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }) // Pull the maximum vCPUs. maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) @@ -197,6 +194,9 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + m.kernel.Init(ring0.KernelOpts{ + PageTables: pagetables.New(newAllocator()), + }, m.maxVCPUs) // Pull the maximum slots. maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -219,15 +219,9 @@ func newMachine(vm int) (*machine, error) { pagetables.MapOpts{AccessType: usermem.AnyAccess}, pr.physical) - // And keep everything in the upper half. - m.kernel.PageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. }) + m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -365,6 +359,11 @@ func (m *machine) Destroy() { // Get gets an available vCPU. // // This will return with the OS thread locked. +// +// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points +// to the vCPU in which the OS thread TID is running. So if Get() returns with +// the corrent context in guest, the vCPU of it must be the same as what +// Get() returns. func (m *machine) Get() *vCPU { m.mu.RLock() runtime.LockOSThread() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index acc823ba6..6849ab113 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -144,6 +144,7 @@ func (c *vCPU) initArchState() error { // Set the entrypoint for the kernel. kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer()) kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer()) + kernelUserRegs.RSP = c.StackTop() kernelUserRegs.RFLAGS = ring0.KernelFlagsSet // Set the system registers. @@ -345,3 +346,43 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } + +var execRegions []region + +func init() { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { + return + } + + if vr.accessType.Execute { + execRegions = append(execRegions, vr.region) + } + }) +} + +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + for _, r := range execRegions { + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossilbe translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) + } + for start, end := range m.kernel.EntryRegions() { + regionLen := end - start + physical, length, ok := translateToPhysical(start) + if !ok || length < regionLen { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|start), + regionLen, + pagetables.MapOpts{AccessType: usermem.ReadWrite}, + physical) + } +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 9db171af9..2df762991 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -19,6 +19,7 @@ package kvm import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -48,6 +49,18 @@ const ( poolPCIDs = 8 ) +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + applyPhysicalRegions(func(pr physicalRegion) bool { + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|pr.virtual), + pr.length, + pagetables.MapOpts{AccessType: usermem.AnyAccess}, + pr.physical) + + return true // Keep iterating. + }) +} + // Get all read-only physicalRegions. func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { var rdonlyRegions []region diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 905712076..537419657 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error { } // tcr_el1 - data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1 reg.id = _KVM_ARM64_REGS_TCR_EL1 if err := c.setOneRegister(®); err != nil { return err @@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error { c.SetTtbr0Kvm(uintptr(data)) // ttbr1_el1 - data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1) reg.id = _KVM_ARM64_REGS_TTBR1_EL1 if err := c.setOneRegister(®); err != nil { diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index e04165fbf..fc43cc3c0 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/hostcpu", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go index 1e07cfd0d..b0970e356 100644 --- a/pkg/sentry/platform/ptrace/filters.go +++ b/pkg/sentry/platform/ptrace/filters.go @@ -24,10 +24,9 @@ import ( // SyscallFilters returns syscalls made exclusively by the ptrace platform. func (*PTrace) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - unix.SYS_GETCPU: {}, - unix.SYS_SCHED_SETAFFINITY: {}, - syscall.SYS_PTRACE: {}, - syscall.SYS_TGKILL: {}, - syscall.SYS_WAIT4: {}, + unix.SYS_GETCPU: {}, + syscall.SYS_PTRACE: {}, + syscall.SYS_TGKILL: {}, + syscall.SYS_WAIT4: {}, } } diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index e1d54d8a2..812ab80ef 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -518,11 +518,6 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { } defer c.interrupt.Disable() - // Ensure that the CPU set is bound appropriately; this makes the - // emulation below several times faster, presumably by avoiding - // interprocessor wakeups and by simplifying the schedule. - t.bind() - // Set registers. if err := t.setRegs(regs); err != nil { panic(fmt.Sprintf("ptrace set regs (%+v) failed: %v", regs, err)) diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 84b699f0d..020bbda79 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -201,7 +201,7 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi seccomp.RuleSet{ Rules: seccomp.SyscallRules{ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + {seccomp.EqualTo(linux.ARCH_SET_CPUID), seccomp.EqualTo(0)}, }, }, Action: linux.SECCOMP_RET_ALLOW, diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index 2ce528601..8548853da 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -80,9 +80,9 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro Rules: seccomp.SyscallRules{ syscall.SYS_CLONE: []seccomp.Rule{ // Allow creation of new subprocesses (used by the master). - {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)}, + {seccomp.EqualTo(syscall.CLONE_FILES | syscall.SIGKILL)}, // Allow creation of new threads within a single address space (used by addresss spaces). - {seccomp.AllowValue( + {seccomp.EqualTo( syscall.CLONE_FILES | syscall.CLONE_FS | syscall.CLONE_SIGHAND | @@ -97,14 +97,14 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // For the stub prctl dance (all). syscall.SYS_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)}, + {seccomp.EqualTo(syscall.PR_SET_PDEATHSIG), seccomp.EqualTo(syscall.SIGKILL)}, }, syscall.SYS_GETPPID: {}, // For the stub to stop itself (all). syscall.SYS_GETPID: {}, syscall.SYS_KILL: []seccomp.Rule{ - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SIGSTOP)}, }, // Injected to support the address space operations. @@ -115,7 +115,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }) } rules = appendArchSeccompRules(rules, defaultAction) - instrs, err := seccomp.BuildProgram(rules, defaultAction) + instrs, err := seccomp.BuildProgram(rules, defaultAction, defaultAction) if err != nil { return nil, err } diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 245b20722..533e45497 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,29 +18,12 @@ package ptrace import ( - "sync/atomic" "syscall" "unsafe" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/hostcpu" - "gvisor.dev/gvisor/pkg/sync" ) -// maskPool contains reusable CPU masks for setting affinity. Unfortunately, -// runtime.NumCPU doesn't actually record the number of CPUs on the system, it -// just records the number of CPUs available in the scheduler affinity set at -// startup. This may a) change over time and b) gives a number far lower than -// the maximum indexable CPU. To prevent lots of allocation in the hot path, we -// use a pool to store large masks that we can reuse during bind. -var maskPool = sync.Pool{ - New: func() interface{} { - const maxCPUs = 1024 // Not a hard limit; see below. - return make([]uintptr, maxCPUs/64) - }, -} - // unmaskAllSignals unmasks all signals on the current thread. // //go:nosplit @@ -49,47 +32,3 @@ func unmaskAllSignals() syscall.Errno { _, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0) return errno } - -// setCPU sets the CPU affinity. -func (t *thread) setCPU(cpu uint32) error { - mask := maskPool.Get().([]uintptr) - n := int(cpu / 64) - v := uintptr(1 << uintptr(cpu%64)) - if n >= len(mask) { - // See maskPool note above. We've actually exceeded the number - // of available cores. Grow the mask and return it. - mask = make([]uintptr, n+1) - } - mask[n] |= v - if _, _, errno := syscall.RawSyscall( - unix.SYS_SCHED_SETAFFINITY, - uintptr(t.tid), - uintptr(len(mask)*8), - uintptr(unsafe.Pointer(&mask[0]))); errno != 0 { - return errno - } - mask[n] &^= v - maskPool.Put(mask) - return nil -} - -// bind attempts to ensure that the thread is on the same CPU as the current -// thread. This provides no guarantees as it is fundamentally a racy operation: -// CPU sets may change and we may be rescheduled in the middle of this -// operation. As a result, no failures are reported. -// -// Precondition: the current runtime thread should be locked. -func (t *thread) bind() { - currentCPU := hostcpu.GetCPU() - - if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { - // Set the affinity on the thread and save the CPU for next - // round; we don't expect CPUs to bounce around too frequently. - // - // (It's worth noting that we could move CPUs between this point - // and when the tracee finishes executing. But that would be - // roughly the status quo anyways -- we're just maximizing our - // chances of colocation, not guaranteeing it.) - t.setCPU(currentCPU) - } -} diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 9c6c2cf5c..f617519fa 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -76,15 +76,42 @@ type KernelOpts struct { type KernelArchState struct { KernelOpts + // cpuEntries is array of kernelEntry for all cpus + cpuEntries []kernelEntry + // globalIDT is our set of interrupt gates. - globalIDT idt64 + globalIDT *idt64 } -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { +// kernelEntry contains minimal CPU-specific arch state +// that can be mapped at the upper of the address space. +// Malicious APP might steal info from it via CPU bugs. +type kernelEntry struct { // stack is the stack used for interrupts on this CPU. stack [256]byte + // scratch space for temporary usage. + scratch0 uint64 + scratch1 uint64 + + // stackTop is the top of the stack. + stackTop uint64 + + // cpuSelf is back reference to CPU. + cpuSelf *CPU + + // kernelCR3 is the cr3 used for sentry kernel. + kernelCR3 uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { // errorCode is the error code from the last exception. errorCode uintptr @@ -97,11 +124,7 @@ type CPUArchState struct { // exception. errorType uintptr - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 + *kernelEntry } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 0e2ab716c..508236e46 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -77,6 +77,9 @@ type CPUArchState struct { // lazyVFP is the value of cpacr_el1. lazyVFP uintptr + + // appASID is the asid value of guest application. + appASID uintptr } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go index 7fa43c2f5..d87b1fd00 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/sentry/platform/ring0/entry_amd64.go @@ -36,12 +36,15 @@ func sysenter() // This must be called prior to sysret/iret. func swapgs() +// jumpToKernel jumps to the kernel version of the current RIP. +func jumpToKernel() + // sysret returns to userspace from a system call. // // The return code is the vector that interrupted execution. // // See stubs.go for a note regarding the frame size of this function. -func sysret(*CPU, *arch.Registers) Vector +func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // "iret is the cadillac of CPL switching." // @@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector // iret is nearly identical to sysret, except an iret is used to fully restore // all user state. This must be called in cases where all registers need to be // restored. -func iret(*CPU, *arch.Registers) Vector +func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // exception is the generic exception entry. // diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s index 02df38331..82689a194 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/sentry/platform/ring0/entry_amd64.s @@ -63,6 +63,15 @@ MOVQ offset+PTRACE_RSI(reg), SI; \ MOVQ offset+PTRACE_RDI(reg), DI; +// WRITE_CR3() writes the given CR3 value. +// +// The code corresponds to: +// +// mov %rax, %cr3 +// +#define WRITE_CR3() \ + BYTE $0x0f; BYTE $0x22; BYTE $0xd8; + // SWAP_GS swaps the kernel GS (CPU). #define SWAP_GS() \ BYTE $0x0F; BYTE $0x01; BYTE $0xf8; @@ -75,15 +84,9 @@ #define SYSRET64() \ BYTE $0x48; BYTE $0x0f; BYTE $0x07; -// LOAD_KERNEL_ADDRESS loads a kernel address. -#define LOAD_KERNEL_ADDRESS(from, to) \ - MOVQ from, to; \ - ORQ ·KernelStartAddress(SB), to; - // LOAD_KERNEL_STACK loads the kernel stack. -#define LOAD_KERNEL_STACK(from) \ - LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \ - LEAQ CPU_STACK_TOP(SP), SP; +#define LOAD_KERNEL_STACK(entry) \ + MOVQ ENTRY_STACK_TOP(entry), SP; // See kernel.go. TEXT ·Halt(SB),NOSPLIT,$0 @@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0 SWAP_GS() RET +// jumpToKernel changes execution to the kernel address space. +// +// This works by changing the return value to the kernel version. +TEXT ·jumpToKernel(SB),NOSPLIT,$0 + MOVQ 0(SP), AX + ORQ ·KernelStartAddress(SB), AX // Future return value. + MOVQ AX, 0(SP) + RET + // See entry_amd64.go. TEXT ·sysret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) + // save SP AX userCR3 on the kernel stack. + MOVQ CPU_ENTRY(BX), BX + LOAD_KERNEL_STACK(BX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_RAX(AX) + PUSHQ CX + // Restore user register state. REGISTERS_LOAD(AX, 0) MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET. MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET. - MOVQ PTRACE_RSP(AX), SP // Restore the stack directly. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + + // restore userCR3, AX, SP. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. + POPQ SP // Restore SP. SYSRET64() // See entry_amd64.go. TEXT ·iret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) // Build an IRET frame & restore state. + MOVQ CPU_ENTRY(BX), BX LOAD_KERNEL_STACK(BX) - MOVQ PTRACE_SS(AX), BX; PUSHQ BX - MOVQ PTRACE_RSP(AX), CX; PUSHQ CX - MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX - MOVQ PTRACE_CS(AX), DI; PUSHQ DI - MOVQ PTRACE_RIP(AX), SI; PUSHQ SI - REGISTERS_LOAD(AX, 0) // Restore most registers. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + PUSHQ PTRACE_SS(AX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_FLAGS(AX) + PUSHQ PTRACE_CS(AX) + PUSHQ PTRACE_RIP(AX) + PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack. + PUSHQ CX // Save userCR3 on kernel stack. + REGISTERS_LOAD(AX, 0) // Restore most registers. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. IRET() // See entry_amd64.go. TEXT ·resume(SB),NOSPLIT,$0 // See iret, above. - MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX - MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX - MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI - MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI - REGISTERS_LOAD(GS, CPU_REGISTERS) - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + PUSHQ CPU_REGISTERS+PTRACE_SS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RSP(AX) + PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX) + PUSHQ CPU_REGISTERS+PTRACE_CS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RIP(AX) + REGISTERS_LOAD(AX, CPU_REGISTERS) + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX IRET() // See entry_amd64.go. TEXT ·Start(SB),NOSPLIT,$0 - LOAD_KERNEL_STACK(AX) // Set the stack. PUSHQ $0x0 // Previous frame pointer. MOVQ SP, BP // Set frame pointer. PUSHQ AX // First argument (CPU). @@ -164,44 +202,54 @@ TEXT ·sysenter(SB),NOSPLIT,$0 user: SWAP_GS() - XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks. - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs). + MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value. - MOVQ BX, PTRACE_RAX(AX) // Save everything else. - MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ CX, PTRACE_RIP(AX) MOVQ R11, PTRACE_FLAGS(AX) - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. + MOVQ SP, PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value. + MOVQ CX, PTRACE_RAX(AX) // Save everything else. + MOVQ CX, PTRACE_ORIGRAX(AX) + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks. + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. // Return to the kernel, where the frame is: // - // vector (sp+24) + // vector (sp+32) + // userCR3 (sp+24) // regs (sp+16) // cpu (sp+8) // vcpu.Switch (sp+0) // - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ $Syscall, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ $Syscall, 32(SP) // Output vector. RET kernel: // We can't restore the original stack, but we can access the registers // in the CPU state directly. No need for temporary juggling. - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ AX, ENTRY_SCRATCH0(GS) + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), BX + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. // Call the syscall trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ AX // First argument (vCPU). CALL ·kernelSyscall(SB) // Call the trampoline. POPQ AX // Pop vCPU. @@ -237,9 +285,14 @@ user: SWAP_GS() ADDQ $-8, SP // Adjust for flags. MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ). - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs. + PUSHQ AX // Save user AX on stack. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX. + POPQ BX // Restore original AX. MOVQ BX, PTRACE_RAX(AX) // Save it. MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX) @@ -249,34 +302,36 @@ user: MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX) // Copy out and return. + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. MOVQ 0(SP), BX // Load vector. MOVQ 8(SP), CX // Load error code. - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version). - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ CX, CPU_ERROR_CODE(GS) // Set error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. - MOVQ BX, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version). + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ CX, CPU_ERROR_CODE(AX) // Set error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. + MOVQ BX, 32(SP) // Output vector. RET kernel: // As per above, we can save directly. - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS) + PUSHQ AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + POPQ BX + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX) // Set the error code and adjust the stack. - MOVQ 8(SP), AX // Load the error code. - MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ 8(SP), BX // Load the error code. + MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. MOVQ 0(SP), BX // BX contains the vector. - ADDQ $48, SP // Drop the exception frame. // Call the exception trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ BX // Second argument (vector). PUSHQ AX // First argument (vCPU). CALL ·kernelException(SB) // Call the trampoline. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 9d29b7168..5f63cbd45 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -27,7 +27,9 @@ // ERET returns using the ELR and SPSR for the current exception level. #define ERET() \ - WORD $0xd69f03e0 + WORD $0xd69f03e0; \ + DSB $7; \ + ISB $15; // RSV_REG is a register that holds el1 information temporarily. #define RSV_REG R18_PLATFORM @@ -300,17 +302,23 @@ // SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application. #define SWITCH_TO_APP_PAGETABLE(from) \ - MOVD CPU_TTBR0_APP(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD CPU_APP_ASID(from), R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set) ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; // SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable. #define SWITCH_TO_KVM_PAGETABLE(from) \ - MOVD CPU_TTBR0_KVM(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD $1, R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ @@ -326,23 +334,20 @@ #define KERNEL_ENTRY_FROM_EL0 \ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable. - SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 - MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. - REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + WORD $0xd538d092; \ // MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step2, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step3, save app context. MOVD RSV_REG_APP, R20; \ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ ADD $16, RSP, RSP; \ MOVD RSV_REG, PTRACE_R18(R20); \ MOVD RSV_REG_APP, PTRACE_R9(R20); \ - MOVD R20, RSV_REG_APP; \ WORD $0xd5384003; \ // MRS SPSR_EL1, R3 - MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MOVD R3, PTRACE_PSTATE(R20); \ MRS ELR_EL1, R3; \ - MOVD R3, PTRACE_PC(RSV_REG_APP); \ + MOVD R3, PTRACE_PC(R20); \ WORD $0xd5384103; \ // MRS SP_EL0, R3 - MOVD R3, PTRACE_SP(RSV_REG_APP); + MOVD R3, PTRACE_SP(R20); // KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1. #define KERNEL_ENTRY_FROM_EL1 \ @@ -357,6 +362,13 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// storeAppASID writes the application's asid value. +TEXT ·storeAppASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + MRS TPIDR_EL1, RSV_REG + MOVD R1, CPU_APP_ASID(RSV_REG) + RET + // Halt halts execution. TEXT ·Halt(SB),NOSPLIT,$0 // Clear bluepill. @@ -414,7 +426,7 @@ TEXT ·Current(SB),NOSPLIT,$0-8 MOVD R8, ret+0(FP) RET -#define STACK_FRAME_SIZE 16 +#define STACK_FRAME_SIZE 32 // kernelExitToEl0 is the entrypoint for application in guest_el0. // Prepare the vcpu environment for container application. @@ -458,15 +470,16 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 SUB $STACK_FRAME_SIZE, RSP, RSP STP (RSV_REG, RSV_REG_APP), 16*0(RSP) + STP (R0, R1), 16*1(RSP) WORD $0xd538d092 //MRS TPIDR_EL1, R18 SWITCH_TO_APP_PAGETABLE(RSV_REG) + LDP 16*1(RSP), (R0, R1) LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) ADD $STACK_FRAME_SIZE, RSP, RSP - ISB $15 ERET() // kernelExitToEl1 is the entrypoint for sentry in guest_el1. @@ -482,6 +495,9 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP + SWITCH_TO_KVM_PAGETABLE(RSV_REG) + MRS TPIDR_EL1, RSV_REG + REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 549f3d228..9742308d8 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,7 +24,10 @@ go_binary( "defs_impl_arm64.go", "main.go", ], - visibility = ["//pkg/sentry/platform/ring0:__pkg__"], + visibility = [ + "//pkg/sentry/platform/kvm:__pkg__", + "//pkg/sentry/platform/ring0:__pkg__", + ], deps = [ "//pkg/cpuid", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 021693791..264be23d3 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -19,8 +19,8 @@ package ring0 // N.B. that constraints on KernelOpts must be satisfied. // //go:nosplit -func (k *Kernel) Init(opts KernelOpts) { - k.init(opts) +func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { + k.init(opts, maxCPUs) } // Halt halts execution. @@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) { // kernelSyscall is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) { // kernelException is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) { // Init initializes a new CPU. // // Init allows embedding in other objects. -func (c *CPU) Init(k *Kernel, hooks Hooks) { - c.self = c // Set self reference. - c.kernel = k // Set kernel reference. - c.init() // Perform architectural init. +func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) { + c.self = c // Set self reference. + c.kernel = k // Set kernel reference. + c.init(cpuID) // Perform architectural init. // Require hooks. if hooks != nil { diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index d37981dbf..3a9dff4cc 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -18,13 +18,42 @@ package ring0 import ( "encoding/binary" + "reflect" + + "gvisor.dev/gvisor/pkg/usermem" ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables + entrySize := reflect.TypeOf(kernelEntry{}).Size() + var ( + entries []kernelEntry + padding = 1 + ) + for { + entries = make([]kernelEntry, maxCPUs+padding-1) + totalSize := entrySize * uintptr(maxCPUs+padding-1) + addr := reflect.ValueOf(&entries[0]).Pointer() + if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize { + // The runtime forces power-of-2 alignment for allocations, and we are therefore + // safe once the first address is aligned and the chunk is at least a full page. + break + } + padding = padding << 1 + } + k.cpuEntries = entries + + k.globalIDT = &idt64{} + if reflect.TypeOf(idt64{}).Size() != usermem.PageSize { + panic("Size of globalIDT should be PageSize") + } + if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 { + panic("Allocated globalIDT should be page aligned") + } + // Setup the IDT, which is uniform. for v, handler := range handlers { // Allow Breakpoint and Overflow to be called from all @@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) { } } +func (k *Kernel) EntryRegions() map[uintptr]uintptr { + regions := make(map[uintptr]uintptr) + + addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer() + size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries)) + end, _ := usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + addr = reflect.ValueOf(k.globalIDT).Pointer() + size = reflect.TypeOf(idt64{}).Size() + end, _ = usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + return regions +} + // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { + c.kernelEntry = &c.kernel.cpuEntries[cpuID] + c.cpuSelf = c // Null segment. c.gdt[0].setNull() @@ -65,6 +112,7 @@ func (c *CPU) init() { // Set the kernel stack pointer in the TSS (virtual address). stackAddr := c.StackTop() + c.stackTop = stackAddr c.tss.rsp0Lo = uint32(stackAddr) c.tss.rsp0Hi = uint32(stackAddr >> 32) c.tss.ist1Lo = uint32(stackAddr) @@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID) - kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID) + c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)) // Sanitize registers. regs := switchOpts.Registers @@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point. - jumpToKernel() // Switch to upper half. - writeCR3(uintptr(userCR3)) // Change to user address space. if switchOpts.FullRestore { - vector = iret(c, regs) + vector = iret(c, regs, uintptr(userCR3)) } else { - vector = sysret(c, regs) + vector = sysret(c, regs, uintptr(userCR3)) } - writeCR3(uintptr(kernelCR3)) // Return to kernel address space. - jumpToUser() // Return to lower half. SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point. WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return @@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { //go:nosplit func start(c *CPU) { // Save per-cpu & FS segment. - WriteGS(kernelAddr(c)) + WriteGS(kernelAddr(c.kernelEntry)) WriteFS(uintptr(c.registers.Fs_base)) // Initialize floating point. diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index d0afa1aaa..0ca98a7c7 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,13 +25,13 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables } // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { // Set the kernel stack pointer(virtual address). c.registers.Sp = uint64(c.StackTop()) @@ -53,6 +53,11 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { + storeAppASID(uintptr(switchOpts.UserASID)) + if switchOpts.Flush { + FlushTlbAll() + } + regs := switchOpts.Registers regs.Pstate &= ^uint64(PsrFlagsClear) diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go index ca968a036..0ec5c3bc5 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.go +++ b/pkg/sentry/platform/ring0/lib_amd64.go @@ -61,21 +61,9 @@ func wrgsbase(addr uintptr) // wrgsmsr writes to the GS_BASE MSR. func wrgsmsr(addr uintptr) -// writeCR3 writes the CR3 value. -func writeCR3(phys uintptr) - -// readCR3 reads the current CR3 value. -func readCR3() uintptr - // readCR2 reads the current CR2 value. func readCR2() uintptr -// jumpToKernel jumps to the kernel version of the current RIP. -func jumpToKernel() - -// jumpToUser jumps to the user version of the current RIP. -func jumpToUser() - // fninit initializes the floating point unit. func fninit() diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s index 75d742750..2fe83568a 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.s +++ b/pkg/sentry/platform/ring0/lib_amd64.s @@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8 BYTE $0x0f; BYTE $0x30; // WRMSR RET -// jumpToUser changes execution to the user address. -// -// This works by changing the return value to the user version. -TEXT ·jumpToUser(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - NOTQ BX - ANDQ BX, SP // Switch the stack. - ANDQ BX, BP // Switch the frame pointer. - ANDQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// jumpToKernel changes execution to the kernel address space. -// -// This works by changing the return value to the kernel version. -TEXT ·jumpToKernel(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - ORQ BX, SP // Switch the stack. - ORQ BX, BP // Switch the frame pointer. - ORQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// writeCR3 writes the given CR3 value. -// -// The code corresponds to: -// -// mov %rax, %cr3 -// -TEXT ·writeCR3(SB),NOSPLIT,$0-8 - MOVQ cr3+0(FP), AX - BYTE $0x0f; BYTE $0x22; BYTE $0xd8; - RET - -// readCR3 reads the current CR3 value. -// -// The code corresponds to: -// -// mov %cr3, %rax -// -TEXT ·readCR3(SB),NOSPLIT,$0-8 - BYTE $0x0f; BYTE $0x20; BYTE $0xd8; - MOVQ AX, ret+0(FP) - RET - // readCR2 reads the current CR2 value. // // The code corresponds to: diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 00e52c8af..2f1abcb0f 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -16,6 +16,15 @@ package ring0 +// storeAppASID writes the application's asid value. +func storeAppASID(asid uintptr) + +// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. +func LocalFlushTlbAll() + +// FlushTlbAll flush all tlb. +func FlushTlbAll() + // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 86bfbe46f..8aabf7d0e 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,20 @@ #include "funcdata.h" #include "textflag.h" +TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 + DSB $6 // dsb(nshst) + WORD $0xd508871f // __tlbi(vmalle1) + DSB $7 // dsb(nsh) + ISB $15 + RET + +TEXT ·FlushTlbAll(SB),NOSPLIT,$0 + DSB $10 // dsb(ishst) + WORD $0xd508831f // __tlbi(vmalle1is) + DSB $11 // dsb(ish) + ISB $15 + RET + TEXT ·GetTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go index b8ab120a0..bcb73cb31 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/sentry/platform/ring0/offsets_amd64.go @@ -30,11 +30,18 @@ func Emit(w io.Writer) { c := &CPU{} fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + + e := &kernelEntry{} + fmt.Fprintf(w, "\n// CPU entry offsets.\n") + fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_SCRATCH1 0x%02x\n", reflect.ValueOf(&e.scratch1).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index f3de962f0..1d86b4bcf 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -41,6 +41,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ASID 0x%02x\n", reflect.ValueOf(&c.appASID).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go index 9da0ea685..e99da0b35 100644 --- a/pkg/sentry/platform/ring0/x86.go +++ b/pkg/sentry/platform/ring0/x86.go @@ -104,7 +104,7 @@ const ( VirtualizationException SecurityException = 0x1e SyscallInt80 = 0x80 - _NR_INTERRUPTS = SyscallInt80 + 1 + _NR_INTERRUPTS = 0x100 ) // System call vectors. diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c0fd3425b..a3f775d15 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/marshal", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -20,6 +21,5 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index e76e498de..b6ebe29d6 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -21,6 +21,8 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", @@ -37,11 +39,12 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 242e6bf76..7d3c4a01c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -24,6 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -36,8 +38,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 8a1d52ebf..163af329b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -52,6 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) s := &socketVFS2{ socketOpsCommon: socketOpsCommon{ @@ -77,6 +78,13 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *socketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) @@ -97,11 +105,6 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// Allocate implements vfs.FileDescriptionImpl.Allocate. -func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { - return syserror.ENODEV -} - // PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index fda3dcb35..faa61160e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -30,6 +30,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -59,6 +62,8 @@ type Stack struct { tcpSACKEnabled bool netDevFile *os.File netSNMPFile *os.File + ipv4Forwarding bool + ipv6Forwarding bool } // NewStack returns an empty Stack containing no configuration. @@ -118,6 +123,13 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } + s.ipv6Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil { + s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + } else { + log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false") + } + return nil } @@ -468,3 +480,21 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber: + return s.ipv4Forwarding + case ipv6.ProtocolNumber: + return s.ipv6Forwarding + default: + log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) + return false + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 0336a32d8..549787955 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -37,7 +39,7 @@ type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. name() string - // marshal converts from an stack.Matcher to an ABI struct. + // marshal converts from a stack.Matcher to an ABI struct. marshal(matcher stack.Matcher) []byte // unmarshal converts from the ABI matcher struct to an @@ -93,3 +95,71 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf } return matchMaker.unmarshal(buf, filter) } + +// targetMaker knows how to (un)marshal a target. Once registered, +// marshalTarget and unmarshalTarget can be used. +type targetMaker interface { + // id uniquely identifies the target. + id() stack.TargetID + + // marshal converts from a stack.Target to an ABI struct. + marshal(target stack.Target) []byte + + // unmarshal converts from the ABI matcher struct to a stack.Target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) +} + +// targetMakers maps the TargetID of supported targets to the targetMaker that +// marshals and unmarshals it. It is immutable after package initialization. +var targetMakers = map[stack.TargetID]targetMaker{} + +func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { + tid := stack.TargetID{ + Name: name, + NetworkProtocol: netProto, + Revision: rev, + } + if _, ok := targetMakers[tid]; !ok { + return 0, false + } + + // Return the highest supported revision unless rev is higher. + for _, other := range targetMakers { + otherID := other.id() + if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { + rev = uint8(otherID.Revision) + } + } + return rev, true +} + +// registerTargetMaker should be called by target extensions to register them +// with the netfilter package. +func registerTargetMaker(tm targetMaker) { + if _, ok := targetMakers[tm.id()]; ok { + panic(fmt.Sprintf("multiple targets registered with name %q.", tm.id())) + } + targetMakers[tm.id()] = tm +} + +func marshalTarget(target stack.Target) []byte { + targetMaker, ok := targetMakers[target.ID()] + if !ok { + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + } + return targetMaker.marshal(target) +} + +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { + tid := stack.TargetID{ + Name: target.Name.String(), + NetworkProtocol: filter.NetworkProtocol(), + Revision: target.Revision, + } + targetMaker, ok := targetMakers[tid] + if !ok { + nflog("unsupported target with name %q", target.Name.String()) + return nil, syserr.ErrInvalidArgument + } + return targetMaker.unmarshal(buf, filter) +} diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index e4c55a100..b560fae0d 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -181,18 +181,23 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - target, err := parseTarget(filter, optVal[:targetSize]) - if err != nil { - nflog("failed to parse target: %v", err) - return nil, syserr.ErrInvalidArgument - } - optVal = optVal[targetSize:] - table.Rules = append(table.Rules, stack.Rule{ + rule := stack.Rule{ Filter: filter, - Target: target, Matchers: matchers, - }) + } + + { + target, err := parseTarget(filter, optVal[:targetSize], false /* ipv6 */) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, err + } + rule.Target = target + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, rule) offsets[offset] = int(entryIdx) offset += uint32(entry.NextOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 3b2c1becd..4253f7bf4 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -184,18 +184,23 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - target, err := parseTarget(filter, optVal[:targetSize]) - if err != nil { - nflog("failed to parse target: %v", err) - return nil, syserr.ErrInvalidArgument - } - optVal = optVal[targetSize:] - table.Rules = append(table.Rules, stack.Rule{ + rule := stack.Rule{ Filter: filter, - Target: target, Matchers: matchers, - }) + } + + { + target, err := parseTarget(filter, optVal[:targetSize], true /* ipv6 */) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, err + } + rule.Target = target + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, rule) offsets[offset] = int(entryIdx) offset += uint32(entry.NextOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3e1735079..904a12e38 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -195,7 +196,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // Check the user chains. for ruleIdx, rule := range table.Rules { - if _, ok := rule.Target.(stack.UserChainTarget); !ok { + if _, ok := rule.Target.(*stack.UserChainTarget); !ok { continue } @@ -216,7 +217,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // Set each jump to point to the appropriate rule. Right now they hold byte // offsets. for ruleIdx, rule := range table.Rules { - jump, ok := rule.Target.(JumpTarget) + jump, ok := rule.Target.(*JumpTarget) if !ok { continue } @@ -307,7 +308,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool { return false } switch rule.Target.(type) { - case stack.AcceptTarget, stack.DropTarget: + case *stack.AcceptTarget, *stack.DropTarget: return true default: return false @@ -318,7 +319,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(stack.AcceptTarget) + _, ok := rule.Target.(*stack.AcceptTarget) return ok } @@ -337,3 +338,20 @@ func hookFromLinux(hook int) stack.Hook { } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } + +// TargetRevision returns a linux.XTGetRevision for a given target. It sets +// Revision to the highest supported value, unless the provided revision number +// is larger. +func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) { + // Read in the target name and version. + var rev linux.XTGetRevision + if _, err := rev.CopyIn(t, revPtr); err != nil { + return linux.XTGetRevision{}, syserr.FromError(err) + } + maxSupported, ok := targetRevision(rev.Name.String(), netProto, rev.Revision) + if !ok { + return linux.XTGetRevision{}, syserr.ErrProtocolNotSupported + } + rev.Revision = maxSupported + return rev, nil +} diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 87e41abd8..0e14447fe 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,255 +15,357 @@ package netfilter import ( - "errors" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) -// errorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const errorTargetName = "ERROR" +func init() { + // Standard targets include ACCEPT, DROP, RETURN, and JUMP. + registerTargetMaker(&standardTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&standardTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) + + // Both user chains and actual errors are represented in iptables by + // error targets. + registerTargetMaker(&errorTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&errorTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) + + registerTargetMaker(&redirectTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&nfNATTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) +} -// redirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const redirectTargetName = "REDIRECT" +type standardTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} -func marshalTarget(target stack.Target) []byte { +func (sm *standardTargetMaker) id() stack.TargetID { + // Standard targets have the empty string as a name and no revisions. + return stack.TargetID{ + NetworkProtocol: sm.NetworkProtocol, + } +} +func (*standardTargetMaker) marshal(target stack.Target) []byte { + // Translate verdicts the same way as the iptables tool. + var verdict int32 switch tg := target.(type) { - case stack.AcceptTarget: - return marshalStandardTarget(stack.RuleAccept) - case stack.DropTarget: - return marshalStandardTarget(stack.RuleDrop) - case stack.ErrorTarget: - return marshalErrorTarget(errorTargetName) - case stack.UserChainTarget: - return marshalErrorTarget(tg.Name) - case stack.ReturnTarget: - return marshalStandardTarget(stack.RuleReturn) - case stack.RedirectTarget: - return marshalRedirectTarget(tg) - case JumpTarget: - return marshalJumpTarget(tg) + case *stack.AcceptTarget: + verdict = -linux.NF_ACCEPT - 1 + case *stack.DropTarget: + verdict = -linux.NF_DROP - 1 + case *stack.ReturnTarget: + verdict = linux.NF_RETURN + case *JumpTarget: + verdict = int32(tg.Offset) default: panic(fmt.Errorf("unknown target of type %T", target)) } -} - -func marshalStandardTarget(verdict stack.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target") // The target's name will be the empty string. - target := linux.XTStandardTarget{ + xt := linux.XTStandardTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTStandardTarget, }, - Verdict: translateFromStandardVerdict(verdict), + Verdict: verdict, } ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) != linux.SizeOfXTStandardTarget { + nflog("buf has wrong size for standard target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + var standardTarget linux.XTStandardTarget + buf = buf[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict, filter.NetworkProtocol()) + } + // A verdict >= 0 indicates a jump. + return &JumpTarget{ + Offset: uint32(standardTarget.Verdict), + NetworkProtocol: filter.NetworkProtocol(), + }, nil +} + +type errorTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (em *errorTargetMaker) id() stack.TargetID { + // Error targets have no revision. + return stack.TargetID{ + Name: stack.ErrorTargetName, + NetworkProtocol: em.NetworkProtocol, + } } -func marshalErrorTarget(errorName string) []byte { +func (*errorTargetMaker) marshal(target stack.Target) []byte { + var errorName string + switch tg := target.(type) { + case *stack.ErrorTarget: + errorName = stack.ErrorTargetName + case *stack.UserChainTarget: + errorName = tg.Name + default: + panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) + } + // This is an error target named error - target := linux.XTErrorTarget{ + xt := linux.XTErrorTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTErrorTarget, }, } - copy(target.Name[:], errorName) - copy(target.Target.Name[:], errorTargetName) + copy(xt.Name[:], errorName) + copy(xt.Target.Name[:], stack.ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + return binary.Marshal(ret, usermem.ByteOrder, xt) } -func marshalRedirectTarget(rt stack.RedirectTarget) []byte { +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) != linux.SizeOfXTErrorTarget { + nflog("buf has insufficient size for error target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + var errorTarget linux.XTErrorTarget + buf = buf[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named stack.ErrorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch name := errorTarget.Name.String(); name { + case stack.ErrorTargetName: + return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + default: + // User defined chain. + return &stack.UserChainTarget{ + Name: name, + NetworkProtocol: filter.NetworkProtocol(), + }, nil + } +} + +type redirectTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *redirectTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*redirectTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) // This is a redirect target named redirect - target := linux.XTRedirectTarget{ + xt := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTRedirectTarget, }, } - copy(target.Target.Name[:], redirectTargetName) + copy(xt.Target.Name[:], stack.RedirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) - target.NfRange.RangeSize = 1 - if rt.RangeProtoSpecified { - target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeSize = 1 + xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) + xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) < linux.SizeOfXTRedirectTarget { + nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("redirectTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var redirectTarget linux.XTRedirectTarget + buf = buf[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + + // RangeSize should be 1. + nfRange := redirectTarget.NfRange + if nfRange.RangeSize != 1 { + nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("redirectTargetMaker: invalid range flags %d", nfRange.RangeIPV4.Flags) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument } - // Convert port from little endian to big endian. - port := make([]byte, 2) - binary.LittleEndian.PutUint16(port, rt.MinPort) - target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) - binary.LittleEndian.PutUint16(port, rt.MaxPort) - target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) - return binary.Marshal(ret, usermem.ByteOrder, target) + if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP { + nflog("redirectTargetMaker: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + + target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.Port = ntohs(nfRange.RangeIPV4.MinPort) + + return &target, nil } -func marshalJumpTarget(jt JumpTarget) []byte { - nflog("convert to binary: marshalling jump target") +type nfNATTarget struct { + Target linux.XTEntryTarget + Range linux.NFNATRange +} - // The target's name will be the empty string. - target := linux.XTStandardTarget{ +const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange + +type nfNATTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *nfNATTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*nfNATTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) + nt := nfNATTarget{ Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, + TargetSize: nfNATMarhsalledSize, + }, + Range: linux.NFNATRange{ + Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, }, - // Verdict is overloaded by the ABI. When positive, it holds - // the jump offset from the start of the table. - Verdict: int32(jt.Offset), } + copy(nt.Target.Name[:], stack.RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.Addr) + copy(nt.Range.MaxAddr[:], rt.Addr) - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + nt.Range.MinProto = htons(rt.Port) + nt.Range.MaxProto = nt.Range.MinProto + + ret := make([]byte, 0, nfNATMarhsalledSize) + return binary.Marshal(ret, usermem.ByteOrder, nt) } -// translateFromStandardVerdict translates verdicts the same way as the iptables -// tool. -func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { - switch verdict { - case stack.RuleAccept: - return -linux.NF_ACCEPT - 1 - case stack.RuleDrop: - return -linux.NF_DROP - 1 - case stack.RuleReturn: - return linux.NF_RETURN - default: - // TODO(gvisor.dev/issue/170): Support Jump. - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if size := nfNATMarhsalledSize; len(buf) < size { + nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) + return nil, syserr.ErrInvalidArgument } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("nfNATTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var natRange linux.NFNATRange + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] + binary.Unmarshal(buf, usermem.ByteOrder, &natRange) + + // We don't support port or address ranges. + if natRange.MinAddr != natRange.MaxAddr { + nflog("nfNATTargetMaker: MinAddr and MaxAddr are different") + return nil, syserr.ErrInvalidArgument + } + if natRange.MinProto != natRange.MaxProto { + nflog("nfNATTargetMaker: MinProto and MaxProto are different") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/3549): Check for other flags. + // For now, redirect target only supports destination change. + if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags) + return nil, syserr.ErrInvalidArgument + } + + target := stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Addr: tcpip.Address(natRange.MinAddr[:]), + Port: ntohs(natRange.MinProto), + } + + return &target, nil } // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32) (stack.Target, error) { +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return stack.AcceptTarget{}, nil + return &stack.AcceptTarget{NetworkProtocol: netProto}, nil case -linux.NF_DROP - 1: - return stack.DropTarget{}, nil + return &stack.DropTarget{NetworkProtocol: netProto}, nil case -linux.NF_QUEUE - 1: - return nil, errors.New("unsupported iptables verdict QUEUE") + nflog("unsupported iptables verdict QUEUE") + return nil, syserr.ErrInvalidArgument case linux.NF_RETURN: - return stack.ReturnTarget{}, nil + return &stack.ReturnTarget{NetworkProtocol: netProto}, nil default: - return nil, fmt.Errorf("unknown iptables verdict %d", val) + nflog("unknown iptables verdict %d", val) + return nil, syserr.ErrInvalidArgument } } // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { +func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.Target, *syserr.Error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { - return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) + nflog("optVal has insufficient size for entry target %d", len(optVal)) + return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget buf := optVal[:linux.SizeOfXTEntryTarget] binary.Unmarshal(buf, usermem.ByteOrder, &target) - switch target.Name.String() { - case "": - // Standard target. - if len(optVal) != linux.SizeOfXTStandardTarget { - return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) - } - var standardTarget linux.XTStandardTarget - buf = optVal[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - - if standardTarget.Verdict < 0 { - // A Verdict < 0 indicates a non-jump verdict. - return translateToStandardTarget(standardTarget.Verdict) - } - // A verdict >= 0 indicates a jump. - return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil - - case errorTargetName: - // Error target. - if len(optVal) != linux.SizeOfXTErrorTarget { - return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) - } - var errorTarget linux.XTErrorTarget - buf = optVal[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) - - // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named errorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. - // * To mark the start of a user defined chain. These - // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case errorTargetName: - nflog("set entries: error target") - return stack.ErrorTarget{}, nil - default: - // User defined chain. - nflog("set entries: user-defined target %q", name) - return stack.UserChainTarget{Name: name}, nil - } - - case redirectTargetName: - // Redirect target. - if len(optVal) < linux.SizeOfXTRedirectTarget { - return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) - } - - if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p) - } - - var redirectTarget linux.XTRedirectTarget - buf = optVal[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - - // Copy linux.XTRedirectTarget to stack.RedirectTarget. - var target stack.RedirectTarget - nfRange := redirectTarget.NfRange - - // RangeSize should be 1. - if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize) - } - - // TODO(gvisor.dev/issue/170): Check if the flags are valid. - // Also check if we need to map ports or IP. - // For now, redirect target only supports destination port change. - // Port range and IP range are not supported yet. - if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags) - } - target.RangeProtoSpecified = true - - target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) - target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) - - // TODO(gvisor.dev/issue/170): Port range is not supported yet. - if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) - } - - // Convert port from big endian to little endian. - port := make([]byte, 2) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) - target.MinPort = binary.LittleEndian.Uint16(port) - - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) - target.MaxPort = binary.LittleEndian.Uint16(port) - return target, nil - } - // Unknown target. - return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String()) + return unmarshalTarget(target, filter, optVal) } // JumpTarget implements stack.Target. @@ -274,9 +376,31 @@ type JumpTarget struct { // RuleNum is the rule to jump to. RuleNum int + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (jt *JumpTarget) ID() stack.TargetID { + return stack.TargetID{ + NetworkProtocol: jt.NetworkProtocol, + } } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } + +func ntohs(port uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, port) + return usermem.ByteOrder.Uint16(buf) +} + +func htons(port uint16) uint16 { + buf := make([]byte, 2) + usermem.ByteOrder.PutUint16(buf, port) + return binary.BigEndian.Uint16(buf) +} diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 0bfd6c1f4..844acfede 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,17 +97,33 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the stack.Check codepath as matchers are added. + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return false, false - } + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber { + return false, false } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 7ed05461d..63201201c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,19 +94,33 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. - if netHeader.TransportProtocol() != header.UDPProtocolNumber { - return false, false - } + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false } + + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 0546801bf..1f926aa91 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -16,6 +16,8 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", @@ -36,8 +38,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index bb205be0d..e8930f031 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -52,6 +52,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 68a9b9a96..5ddcd4be5 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -21,6 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -38,8 +40,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index a38d25da9..c83b23242 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV return fd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 1fb777a6c..fae3b6783 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -22,6 +22,8 @@ go_library( "//pkg/binary", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/safemem", "//pkg/sentry/arch", @@ -51,8 +53,6 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 0bf21f7d8..87e30d742 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -40,6 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -62,8 +64,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -158,6 +158,9 @@ var Metrics = tcpip.Stats{ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."), + IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."), + IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), @@ -195,7 +198,6 @@ var Metrics = tcpip.Stats{ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), - InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -857,7 +859,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -866,7 +868,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock { + if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { return ep, wq, syserr.TranslateNetstackError(err) } @@ -879,15 +881,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -907,13 +912,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1187,7 +1187,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam var v tcpip.LingerOption var linger linux.Linger if err := ep.GetSockOpt(&v); err != nil { - return &linger, nil + return nil, syserr.TranslateNetstackError(err) } if v.Enabled { @@ -1512,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &vP, nil case linux.IP6T_ORIGINAL_DST: - // TODO(gvisor.dev/issue/170): ip6tables. - return nil, syserr.ErrInvalidArgument + if outLen < int(binary.Size(linux.SockAddrInet6{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: if outLen < linux.SizeOfIPTGetinfo { @@ -1555,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } return &entries, nil + case linux.IP6T_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1718,6 +1747,26 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in } return &entries, nil + case linux.IPT_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil + default: emitUnimplementedEventIP(t, name) } @@ -1770,10 +1819,16 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int case linux.SOL_IP: return setSockOptIP(t, s, ep, name, optVal) + case linux.SOL_PACKET: + // gVisor doesn't support any SOL_PACKET options just return not + // supported. Returning nil here will result in tcpdump thinking AF_PACKET + // features are supported and proceed to use them and break. + t.Kernel().EmitUnimplementedEvent(t) + return syserr.ErrProtocolNotAvailable + case linux.SOL_UDP, linux.SOL_ICMPV6, - linux.SOL_RAW, - linux.SOL_PACKET: + linux.SOL_RAW: t.Kernel().EmitUnimplementedEvent(t) } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 1f7d17f5f..4c6791fff 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -18,6 +18,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" @@ -29,8 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -79,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) @@ -151,14 +158,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs // tcpip.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -176,13 +187,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { + if peerAddr != nil { // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f0fe18684..1028d2a6e 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcp.ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -166,17 +166,17 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcp.ReceiveBufferSizeOption{ + rs := tcpip.TCPReceiveBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &rs)).ToError() } // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcp.SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -187,29 +187,30 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcp.SendBufferSizeOption{ + ss := tcpip.TCPSendBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &ss)).ToError() } // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcp.SACKEnabled + var sack tcpip.TCPSACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() + opt := tcpip.TCPSACKEnabled(enabled) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // TCPRecovery implements inet.Stack.TCPRecovery. func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { - var recovery tcp.Recovery + var recovery tcpip.TCPRecovery if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } @@ -218,7 +219,8 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { // SetTCPRecovery implements inet.Stack.SetTCPRecovery. func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() + opt := tcpip.TCPRecovery(recovery) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // Statistics implements inet.Stack.Statistics. @@ -410,3 +412,24 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { s.Stack.RestoreCleanupEndpoints(es) } + +// Forwarding implements inet.Stack.Forwarding. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + return s.Stack.Forwarding(protocol) + default: + panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol)) + } +} + +// SetForwarding implements inet.Stack.SetForwarding. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { + switch protocol { + case ipv4.ProtocolNumber, ipv6.ProtocolNumber: + s.Stack.SetForwarding(protocol, enable) + default: + panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + } + return nil +} diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 04b259d27..fd31479e5 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -35,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cb953e4dc..cc7408698 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -7,10 +7,21 @@ go_template_instance( name = "socket_refs", out = "socket_refs.go", package = "unix", - prefix = "socketOpsCommon", + prefix = "socketOperations", template = "//pkg/refs_vfs2:refs_template", types = { - "T": "socketOpsCommon", + "T": "SocketOperations", + }, +) + +go_template_instance( + name = "socket_vfs2_refs", + out = "socket_vfs2_refs.go", + package = "unix", + prefix = "socketVFS2", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "SocketVFS2", }, ) @@ -20,6 +31,7 @@ go_library( "device.go", "io.go", "socket_refs.go", + "socket_vfs2_refs.go", "unix.go", "unix_vfs2.go", ], @@ -29,6 +41,7 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", @@ -49,6 +62,5 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index e3a75b519..aa4f3c04d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { } // Accept accepts a new connection. -func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { +func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() defer e.Unlock() @@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { select { case ne := <-e.acceptedChan: + if peerAddr != nil { + ne.Lock() + c := ne.connected + ne.Unlock() + if c != nil { + addr, err := c.GetLocalAddress() + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + *peerAddr = addr + } + } return ne, nil default: diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 4751b2fd8..f8aacca13 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi } // Listen starts listening on the connection. -func (e *connectionlessEndpoint) Listen(int) *syserr.Error { +func (*connectionlessEndpoint) Listen(int) *syserr.Error { return syserr.ErrNotSupported } // Accept accepts a new connection. -func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) { +func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) { return nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index cc9d650fb..d6fc03520 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -151,7 +151,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *syserr.Error) + // + // peerAddr if not nil will be populated with the address of the connected + // peer on a successful accept. + Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -743,6 +746,9 @@ type baseEndpoint struct { // path is not empty if the endpoint has been bound, // or may be used if the endpoint is connected. path string + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // EventRegister implements waiter.Waitable.EventRegister. @@ -838,8 +844,14 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return n, err } -// SetSockOpt sets a socket option. Currently not supported. -func (e *baseEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { +// SetSockOpt sets a socket option. +func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.LingerOption: + e.Lock() + e.linger = *v + e.Unlock() + } return nil } @@ -942,8 +954,11 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch o := opt.(type) { case *tcpip.LingerOption: + e.Lock() + *o = e.linger + e.Unlock() return nil default: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 0a7a26495..f80011ce4 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -39,7 +40,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -55,6 +55,7 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + socketOperationsRefs socketOpsCommon } @@ -84,11 +85,27 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty return fs.NewFile(ctx, d, flags, &s) } +// DecRef implements RefCounter.DecRef. +func (s *SocketOperations) DecRef(ctx context.Context) { + s.socketOperationsRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implemements fs.FileOperations.Release. +func (s *SocketOperations) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // socketOpsCommon contains the socket operations common to VFS1 and VFS2. // // +stateify savable type socketOpsCommon struct { - socketOpsCommonRefs socket.SendReceiveTimeout ep transport.Endpoint @@ -101,23 +118,6 @@ type socketOpsCommon struct { abstractNamespace *kernel.AbstractSocketNamespace } -// DecRef implements RefCounter.DecRef. -func (s *socketOpsCommon) DecRef(ctx context.Context) { - s.socketOpsCommonRefs.DecRef(func() { - s.ep.Close(ctx) - if s.abstractNamespace != nil { - s.abstractNamespace.Remove(s.abstractName, s) - } - }) -} - -// Release implemements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(ctx context.Context) { - // Release only decrements a reference on s because s may be referenced in - // the abstract socket namespace. - s.DecRef(ctx) -} - func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: @@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 65a285b8f..3345124cc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" @@ -32,17 +33,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, // vfs.FileDescriptionImpl) for Unix sockets. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.DentryMetadataFileDescriptionImpl vfs.LockFD + socketVFS2Refs socketOpsCommon } @@ -53,6 +56,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) if err != nil { @@ -88,6 +92,25 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 return vfsfd, nil } +// DecRef implements RefCounter.DecRef. +func (s *SocketVFS2) DecRef(ctx context.Context) { + s.socketVFS2Refs.DecRef(func() { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { @@ -96,7 +119,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.socketOpsCommon.EventRegister(&e, waiter.EventIn) @@ -105,7 +128,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -118,15 +141,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -144,13 +170,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index a06c9b8ab..245d2c5cf 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -61,8 +61,10 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() k.ReceiveTaskStates() - defer k.Unpause() - defer log.Infof("Tasks resumed after save.") + defer func() { + k.Unpause() + log.Infof("Tasks resumed after save.") + }() w.Stop() defer w.Start() diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 88d5db9fc..a920180d3 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", + "//pkg/marshal/primitive", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go index 5d51a7792..ae3b998c8 100644 --- a/pkg/sentry/strace/epoll.go +++ b/pkg/sentry/strace/epoll.go @@ -26,7 +26,7 @@ import ( func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { var e linux.EpollEvent - if _, err := t.CopyIn(eventAddr, &e); err != nil { + if _, err := e.CopyIn(t, eventAddr); err != nil { return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err) } var sb strings.Builder @@ -41,7 +41,7 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui addr := eventsAddr for i := uint64(0); i < numEvents; i++ { var e linux.EpollEvent - if _, err := t.CopyIn(addr, &e); err != nil { + if _, err := e.CopyIn(t, addr); err != nil { fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err) continue } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 08e97e6c4..cc5f70cd4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" @@ -166,7 +167,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) } buf := make([]byte, length) - if _, err := t.CopyIn(addr, &buf); err != nil { + if _, err := t.CopyInBytes(addr, buf); err != nil { return fmt.Sprintf("%#x (error decoding control: %v)", addr, err) } @@ -302,7 +303,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string { var msg slinux.MessageHeader64 - if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil { + if _, err := msg.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err) } s := fmt.Sprintf( @@ -380,9 +381,9 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) { // socklen_t is 32-bits. - var l uint32 - _, err := t.CopyIn(addr, &l) - return l, err + var l primitive.Uint32 + _, err := l.CopyIn(t, addr) + return uint32(l), err } func sockLenPointer(t *kernel.Task, addr usermem.Addr) string { @@ -436,22 +437,22 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string { switch optLen { case 1: - var v uint8 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint8 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } return fmt.Sprintf("%#x {value=%v}", optVal, v) case 2: - var v uint16 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint16 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } return fmt.Sprintf("%#x {value=%v}", optVal, v) case 4: - var v uint32 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint32 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 87b239730..396744597 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -17,17 +17,16 @@ package strace import ( - "encoding/binary" "fmt" "strconv" "strings" - "syscall" "time" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/eventchannel" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -91,7 +90,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma } b := make([]byte, size) - amt, err := t.CopyIn(ar.Start, b) + amt, err := t.CopyInBytes(ar.Start, b) if err != nil { iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err) continue @@ -118,7 +117,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st } b := make([]byte, size) - amt, err := t.CopyIn(addr, b) + amt, err := t.CopyInBytes(addr, b) if err != nil { return fmt.Sprintf("%#x (error decoding string: %s)", addr, err) } @@ -199,7 +198,7 @@ func fdVFS2(t *kernel.Task, fd int32) string { func fdpair(t *kernel.Task, addr usermem.Addr) string { var fds [2]int32 - _, err := t.CopyIn(addr, &fds) + _, err := primitive.CopyInt32SliceIn(t, addr, fds[:]) if err != nil { return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err) } @@ -209,7 +208,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string { func uname(t *kernel.Task, addr usermem.Addr) string { var u linux.UtsName - if _, err := t.CopyIn(addr, &u); err != nil { + if _, err := u.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err) } @@ -222,7 +221,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timespec - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err) } @@ -244,7 +243,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timespec - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err) } return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec) @@ -256,7 +255,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timeval - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err) } @@ -268,8 +267,8 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string { return "null" } - var utim syscall.Utimbuf - if _, err := t.CopyIn(addr, &utim); err != nil { + var utim linux.Utime + if _, err := utim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err) } @@ -282,7 +281,7 @@ func stat(t *kernel.Task, addr usermem.Addr) string { } var stat linux.Stat - if _, err := t.CopyIn(addr, &stat); err != nil { + if _, err := stat.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err) } return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec)) @@ -294,7 +293,7 @@ func itimerval(t *kernel.Task, addr usermem.Addr) string { } interval := timeval(t, addr) - value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{}))) + value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } @@ -304,7 +303,7 @@ func itimerspec(t *kernel.Task, addr usermem.Addr) string { } interval := timespec(t, addr) - value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{}))) + value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } @@ -330,7 +329,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string { } var ru linux.Rusage - if _, err := t.CopyIn(addr, &ru); err != nil { + if _, err := ru.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err) } return fmt.Sprintf("%#x %+v", addr, ru) @@ -342,7 +341,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string { } var hdr linux.CapUserHeader - if _, err := t.CopyIn(addr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding header: %s)", addr, err) } @@ -367,7 +366,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { } var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err) } @@ -376,7 +375,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { switch hdr.Version { case linux.LINUX_CAPABILITY_VERSION_1: var data linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := data.CopyIn(t, dataAddr); err != nil { return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err) } p = uint64(data.Permitted) @@ -384,7 +383,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { e = uint64(data.Effective) case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3: var data [2]linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil { return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err) } p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 4a9b04fd0..75752b2e6 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -56,6 +56,7 @@ go_library( "sys_xattr.go", "timespec.go", ], + marshal = True, visibility = ["//:sandbox"], deps = [ "//pkg/abi", @@ -64,6 +65,8 @@ go_library( "//pkg/bpf", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/rand", "//pkg/safemem", @@ -99,7 +102,5 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index da6bd85e1..5f26697d2 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -138,7 +138,7 @@ var AMD64 = &kernel.SyscallTable{ 83: syscalls.Supported("mkdir", Mkdir), 84: syscalls.Supported("rmdir", Rmdir), 85: syscalls.Supported("creat", Creat), - 86: syscalls.Supported("link", Link), + 86: syscalls.PartiallySupported("link", Link, "Limited support with Gofer. Link count and linked files may get out of sync because gVisor is not aware of external hardlinks.", nil), 87: syscalls.Supported("unlink", Unlink), 88: syscalls.Supported("symlink", Symlink), 89: syscalls.Supported("readlink", Readlink), @@ -317,7 +317,7 @@ var AMD64 = &kernel.SyscallTable{ 262: syscalls.Supported("fstatat", Fstatat), 263: syscalls.Supported("unlinkat", Unlinkat), 264: syscalls.Supported("renameat", Renameat), - 265: syscalls.Supported("linkat", Linkat), + 265: syscalls.PartiallySupported("linkat", Linkat, "See link(2).", nil), 266: syscalls.Supported("symlinkat", Symlinkat), 267: syscalls.Supported("readlinkat", Readlinkat), 268: syscalls.Supported("fchmodat", Fchmodat), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index e9d64dec5..0bf313a13 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -17,6 +17,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -36,7 +37,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // // The context pointer _must_ be zero initially. var idIn uint64 - if _, err := t.CopyIn(idAddr, &idIn); err != nil { + if _, err := primitive.CopyUint64In(t, idAddr, &idIn); err != nil { return 0, nil, err } if idIn != 0 { @@ -49,7 +50,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Copy out the new ID. - if _, err := t.CopyOut(idAddr, &id); err != nil { + if _, err := primitive.CopyUint64Out(t, idAddr, id); err != nil { t.MemoryManager().DestroyAIOContext(t, id) return 0, nil, err } @@ -142,7 +143,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S ev := v.(*linux.IOEvent) // Copy out the result. - if _, err := t.CopyOut(eventsAddr, ev); err != nil { + if _, err := ev.CopyOut(t, eventsAddr); err != nil { if count > 0 { return uintptr(count), nil, nil } @@ -338,21 +339,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } for i := int32(0); i < nrEvents; i++ { - // Copy in the address. - cbAddrNative := t.Arch().Native(0) - if _, err := t.CopyIn(addr, cbAddrNative); err != nil { - if i > 0 { - // Some successful. - return uintptr(i), nil, nil + // Copy in the callback address. + var cbAddr usermem.Addr + switch t.Arch().Width() { + case 8: + var cbAddrP primitive.Uint64 + if _, err := cbAddrP.CopyIn(t, addr); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err } - // Nothing done. - return 0, nil, err + cbAddr = usermem.Addr(cbAddrP) + default: + return 0, nil, syserror.ENOSYS } // Copy in this callback. var cb linux.IOCallback - cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) - if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if _, err := cb.CopyIn(t, cbAddr); err != nil { if i > 0 { // Some have been successful. diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go index adf5ea5f2..d3b85e11b 100644 --- a/pkg/sentry/syscalls/linux/sys_capability.go +++ b/pkg/sentry/syscalls/linux/sys_capability.go @@ -45,7 +45,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal dataAddr := args[1].Pointer() var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return 0, nil, err } // hdr.Pid doesn't need to be valid if this capget() is a "version probe" @@ -65,7 +65,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal Permitted: uint32(p), Inheritable: uint32(i), } - _, err = t.CopyOut(dataAddr, &data) + _, err = data.CopyOut(t, dataAddr) return 0, nil, err case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3: @@ -88,12 +88,12 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal Inheritable: uint32(i >> 32), }, } - _, err = t.CopyOut(dataAddr, &data) + _, err = linux.CopyCapUserDataSliceOut(t, dataAddr, data[:]) return 0, nil, err default: hdr.Version = linux.HighestCapabilityVersion - if _, err := t.CopyOut(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } if dataAddr != 0 { @@ -109,7 +109,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal dataAddr := args[1].Pointer() var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return 0, nil, err } switch hdr.Version { @@ -118,7 +118,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EPERM } var data linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := data.CopyIn(t, dataAddr); err != nil { return 0, nil, err } p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities @@ -131,7 +131,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EPERM } var data [2]linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil { return 0, nil, err } p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities @@ -141,7 +141,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal default: hdr.Version = linux.HighestCapabilityVersion - if _, err := t.CopyOut(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 256422689..98331eb3c 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" @@ -601,19 +602,19 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Shared flags between file and socket. switch request { case linux.FIONCLEX: - t.FDTable().SetFlags(fd, kernel.FDFlags{ + t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: false, }) return 0, nil, nil case linux.FIOCLEX: - t.FDTable().SetFlags(fd, kernel.FDFlags{ + t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: true, }) return 0, nil, nil case linux.FIONBIO: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.Flags() @@ -627,7 +628,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOASYNC: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.Flags() @@ -641,15 +642,14 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOSETOWN, linux.SIOCSPGRP: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } fSetOwn(t, file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: - who := fGetOwn(t, file) - _, err := t.CopyOut(args[2].Pointer(), &who) + _, err := primitive.CopyInt32Out(t, args[2].Pointer(), fGetOwn(t, file)) return 0, nil, err default: @@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Top it off with a terminator. - _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00")) + _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00")) return uintptr(bytes + 1), nil, err } @@ -787,7 +787,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that Remove provides a reference on the file that we may use to // flush. It is still active until we drop the final reference below // (and other reference-holding operations complete). - file, _ := t.FDTable().Remove(fd) + file, _ := t.FDTable().Remove(t, fd) if file == nil { return 0, nil, syserror.EBADF } @@ -941,7 +941,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - err := t.FDTable().SetFlags(fd, kernel.FDFlags{ + err := t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) return 0, nil, err @@ -962,7 +962,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock - if _, err := t.CopyIn(flockAddr, &flock); err != nil { + if _, err := flock.CopyIn(t, flockAddr); err != nil { return 0, nil, err } @@ -1052,12 +1052,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) - _, err := t.CopyOut(addr, &owner) + _, err := owner.CopyOut(t, addr) return 0, nil, err case linux.F_SETOWN_EX: addr := args[2].Pointer() var owner linux.FOwnerEx - _, err := t.CopyIn(addr, &owner) + _, err := owner.CopyIn(t, addr) if err != nil { return 0, nil, err } @@ -1154,6 +1154,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.ThenChange(vfs2/fd.go) + +// LINT.IfChange + func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1918,7 +1922,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times linux.Utime - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := times.CopyIn(t, timesAddr); err != nil { return 0, nil, err } ts = fs.TimeSpec{ @@ -1938,7 +1942,7 @@ func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } ts = fs.TimeSpec{ @@ -1966,7 +1970,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timespec - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) { @@ -2000,7 +2004,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } if times[0].Usec >= 1e6 || times[0].Usec < 0 || diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index 12b2fa690..f39ce0639 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -306,8 +306,8 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel // Despite the syscall using the name 'pid' for this variable, it is // very much a tid. tid := args[0].Int() - head := args[1].Pointer() - size := args[2].Pointer() + headAddr := args[1].Pointer() + sizeAddr := args[2].Pointer() if tid < 0 { return 0, nil, syserror.EINVAL @@ -321,12 +321,16 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel } // Copy out head pointer. - if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil { + head := t.Arch().Native(uintptr(ot.GetRobustList())) + if _, err := head.CopyOut(t, headAddr); err != nil { return 0, nil, err } - // Copy out size, which is a constant. - if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil { + // Copy out size, which is a constant. Note that while size isn't + // an address, it is defined as the arch-dependent size_t, so it + // needs to be converted to a native-sized int. + size := t.Arch().Native(uintptr(linux.SizeOfRobustListHead)) + if _, err := size.CopyOut(t, sizeAddr); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index 59004cefe..b25f7d881 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -19,7 +19,6 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -93,19 +92,23 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir } } -// oldDirentHdr is a fixed sized header matching the fixed size -// fields found in the old linux dirent struct. +// oldDirentHdr is a fixed sized header matching the fixed size fields found in +// the old linux dirent struct. +// +// +marshal type oldDirentHdr struct { Ino uint64 Off uint64 - Reclen uint16 + Reclen uint16 `marshal:"unaligned"` // Struct ends mid-word. } -// direntHdr is a fixed sized header matching the fixed size -// fields found in the new linux dirent struct. +// direntHdr is a fixed sized header matching the fixed size fields found in the +// new linux dirent struct. +// +// +marshal type direntHdr struct { OldHdr oldDirentHdr - Typ uint8 + Typ uint8 `marshal:"unaligned"` // Struct ends mid-word. } // dirent contains the data pointed to by a new linux dirent struct. @@ -134,20 +137,20 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent // the old linux dirent format. func smallestDirent(a arch.Context) uint { d := dirent{} - return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1 + return uint(d.Hdr.OldHdr.SizeBytes()) + a.Width() + 1 } // smallestDirent64 returns the size of the smallest possible dirent using // the new linux dirent format. func smallestDirent64(a arch.Context) uint { d := dirent{} - return uint(binary.Size(d.Hdr)) + a.Width() + return uint(d.Hdr.SizeBytes()) + a.Width() } // padRec pads the name field until the rec length is a multiple of the width, // which must be a power of 2. It returns the padded rec length. func (d *dirent) padRec(width int) uint16 { - a := int(binary.Size(d.Hdr)) + len(d.Name) + a := d.Hdr.SizeBytes() + len(d.Name) r := (a + width) &^ (width - 1) padding := r - a d.Name = append(d.Name, make([]byte, padding)...) @@ -157,7 +160,7 @@ func (d *dirent) padRec(width int) uint16 { // Serialize64 serializes a Dirent struct to a byte slice, keeping the new // linux dirent format. Returns the number of bytes serialized or an error. func (d *dirent) Serialize64(w io.Writer) (int, error) { - n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr)) + n1, err := d.Hdr.WriteTo(w) if err != nil { return 0, err } @@ -165,14 +168,14 @@ func (d *dirent) Serialize64(w io.Writer) (int, error) { if err != nil { return 0, err } - return n1 + n2, nil + return int(n1) + n2, nil } // Serialize serializes a Dirent struct to a byte slice, using the old linux // dirent format. // Returns the number of bytes serialized or an error. func (d *dirent) Serialize(w io.Writer) (int, error) { - n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr)) + n1, err := d.Hdr.OldHdr.WriteTo(w) if err != nil { return 0, err } @@ -184,7 +187,7 @@ func (d *dirent) Serialize(w io.Writer) (int, error) { if err != nil { return 0, err } - return n1 + n2 + n3, nil + return int(n1) + n2 + n3, nil } // direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go index 715ac45e6..a29d307e5 100644 --- a/pkg/sentry/syscalls/linux/sys_identity.go +++ b/pkg/sentry/syscalls/linux/sys_identity.go @@ -49,13 +49,13 @@ func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ruid := c.RealKUID.In(c.UserNamespace).OrOverflow() euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow() suid := c.SavedKUID.In(c.UserNamespace).OrOverflow() - if _, err := t.CopyOut(ruidAddr, ruid); err != nil { + if _, err := ruid.CopyOut(t, ruidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(euidAddr, euid); err != nil { + if _, err := euid.CopyOut(t, euidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(suidAddr, suid); err != nil { + if _, err := suid.CopyOut(t, suidAddr); err != nil { return 0, nil, err } return 0, nil, nil @@ -84,13 +84,13 @@ func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys rgid := c.RealKGID.In(c.UserNamespace).OrOverflow() egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow() sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow() - if _, err := t.CopyOut(rgidAddr, rgid); err != nil { + if _, err := rgid.CopyOut(t, rgidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(egidAddr, egid); err != nil { + if _, err := egid.CopyOut(t, egidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(sgidAddr, sgid); err != nil { + if _, err := sgid.CopyOut(t, sgidAddr); err != nil { return 0, nil, err } return 0, nil, nil @@ -157,7 +157,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys for i, kgid := range kgids { gids[i] = kgid.In(t.UserNamespace()).OrOverflow() } - if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil { + if _, err := auth.CopyGIDSliceOut(t, args[1].Pointer(), gids); err != nil { return 0, nil, err } return uintptr(len(gids)), nil, nil @@ -173,7 +173,7 @@ func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, t.SetExtraGIDs(nil) } gids := make([]auth.GID, size) - if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil { + if _, err := auth.CopyGIDSliceIn(t, args[1].Pointer(), gids); err != nil { return 0, nil, err } return 0, nil, t.SetExtraGIDs(gids) diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index d0109baa4..cd8dfdfa4 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -100,6 +100,15 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if err := file.ConfigureMMap(t, &opts); err != nil { return 0, nil, err } + } else if shared { + // Back shared anonymous mappings with a special mappable. + opts.Offset = 0 + m, err := mm.NewSharedAnonMappable(opts.Length, t.Kernel()) + if err != nil { + return 0, nil, err + } + opts.MappingIdentity = m // transfers ownership of m to opts + opts.Mappable = m } rv, err := t.MemoryManager().MMap(t, opts) @@ -239,7 +248,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, syserror.ENOMEM } resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize)) - _, err := t.CopyOut(vec, resident) + _, err := t.CopyOutBytes(vec, resident) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 3149e4aad..849a47476 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -46,9 +47,9 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { return 0, err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if file, _ := t.FDTable().Remove(fd); file != nil { + if file, _ := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 789e2ed5b..254f4c9f9 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -162,7 +162,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD pfd := make([]linux.PollFD, nfds) if nfds > 0 { - if _, err := t.CopyIn(addr, &pfd); err != nil { + if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil { return nil, err } } @@ -189,7 +189,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. if nfds > 0 && err == nil { - if _, err := t.CopyOut(addr, pfd); err != nil { + if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil { return remainingTimeout, 0, err } } @@ -202,7 +202,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy set := make([]byte, nBytes) if addr != 0 { - if _, err := t.CopyIn(addr, &set); err != nil { + if _, err := t.CopyInBytes(addr, set); err != nil { return nil, err } // If we only use part of the last byte, mask out the extraneous bits. @@ -329,19 +329,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Copy updated vectors back. if readFDs != 0 { - if _, err := t.CopyOut(readFDs, r); err != nil { + if _, err := t.CopyOutBytes(readFDs, r); err != nil { return 0, err } } if writeFDs != 0 { - if _, err := t.CopyOut(writeFDs, w); err != nil { + if _, err := t.CopyOutBytes(writeFDs, w); err != nil { return 0, err } } if exceptFDs != 0 { - if _, err := t.CopyOut(exceptFDs, e); err != nil { + if _, err := t.CopyOutBytes(exceptFDs, e); err != nil { return 0, err } } diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 64a725296..a892d2c62 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -43,7 +44,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, nil case linux.PR_GET_PDEATHSIG: - _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal())) + _, err := primitive.CopyInt32Out(t, args[1].Pointer(), int32(t.ParentDeathSignal())) return 0, nil, err case linux.PR_GET_DUMPABLE: @@ -110,7 +111,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall buf[len] = 0 len++ } - _, err := t.CopyOut(addr, buf[:len]) + _, err := t.CopyOutBytes(addr, buf[:len]) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index d5d5b6959..309c183a3 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -26,17 +27,13 @@ import ( // rlimit describes an implementation of 'struct rlimit', which may vary from // system-to-system. type rlimit interface { + marshal.Marshallable + // toLimit converts an rlimit to a limits.Limit. toLimit() *limits.Limit // fromLimit converts a limits.Limit to an rlimit. fromLimit(lim limits.Limit) - - // copyIn copies an rlimit from the untrusted app to the kernel. - copyIn(t *kernel.Task, addr usermem.Addr) error - - // copyOut copies an rlimit from the kernel to the untrusted app. - copyOut(t *kernel.Task, addr usermem.Addr) error } // newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system. @@ -50,6 +47,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) { } } +// +marshal type rlimit64 struct { Cur uint64 Max uint64 @@ -70,12 +68,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) { } func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error { - _, err := t.CopyIn(addr, r) + _, err := r.CopyIn(t, addr) return err } func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error { - _, err := t.CopyOut(addr, *r) + _, err := r.CopyOut(t, addr) return err } @@ -140,7 +138,8 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } rlim.fromLimit(lim) - return 0, nil, rlim.copyOut(t, addr) + _, err = rlim.CopyOut(t, addr) + return 0, nil, err } // Setrlimit implements linux syscall setrlimit(2). @@ -155,7 +154,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if err != nil { return 0, nil, err } - if err := rlim.copyIn(t, addr); err != nil { + if _, err := rlim.CopyIn(t, addr); err != nil { return 0, nil, syserror.EFAULT } _, err = prlimit64(t, resource, rlim.toLimit()) diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go index 1674c7445..ac5c98a54 100644 --- a/pkg/sentry/syscalls/linux/sys_rusage.go +++ b/pkg/sentry/syscalls/linux/sys_rusage.go @@ -80,7 +80,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } ru := getrusage(t, which) - _, err := t.CopyOut(addr, &ru) + _, err := ru.CopyOut(t, addr) return 0, nil, err } @@ -104,7 +104,7 @@ func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall CUTime: linux.ClockTFromDuration(cs2.UserTime), CSTime: linux.ClockTFromDuration(cs2.SysTime), } - if _, err := t.CopyOut(addr, &r); err != nil { + if _, err := r.CopyOut(t, addr); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go index 99f6993f5..bfcf44b6f 100644 --- a/pkg/sentry/syscalls/linux/sys_sched.go +++ b/pkg/sentry/syscalls/linux/sys_sched.go @@ -27,8 +27,10 @@ const ( ) // SchedParam replicates struct sched_param in sched.h. +// +// +marshal type SchedParam struct { - schedPriority int64 + schedPriority int32 } // SchedGetparam implements linux syscall sched_getparam(2). @@ -45,7 +47,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.ESRCH } r := SchedParam{schedPriority: onlyPriority} - if _, err := t.CopyOut(param, r); err != nil { + if _, err := r.CopyOut(t, param); err != nil { return 0, nil, err } @@ -79,7 +81,7 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke return 0, nil, syserror.ESRCH } var r SchedParam - if _, err := t.CopyIn(param, &r); err != nil { + if _, err := r.CopyIn(t, param); err != nil { return 0, nil, syserror.EINVAL } if r.schedPriority != onlyPriority { diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go index 5b7a66f4d..4fdb4463c 100644 --- a/pkg/sentry/syscalls/linux/sys_seccomp.go +++ b/pkg/sentry/syscalls/linux/sys_seccomp.go @@ -24,6 +24,8 @@ import ( ) // userSockFprog is equivalent to Linux's struct sock_fprog on amd64. +// +// +marshal type userSockFprog struct { // Len is the length of the filter in BPF instructions. Len uint16 @@ -33,7 +35,7 @@ type userSockFprog struct { // Filter is a user pointer to the struct sock_filter array that makes up // the filter program. Filter is a uint64 rather than a usermem.Addr // because usermem.Addr is actually uintptr, which is not a fixed-size - // type, and encoding/binary.Read objects to this. + // type. Filter uint64 } @@ -54,11 +56,11 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error { } var fprog userSockFprog - if _, err := t.CopyIn(addr, &fprog); err != nil { + if _, err := fprog.CopyIn(t, addr); err != nil { return err } filter := make([]linux.BPFInstruction, int(fprog.Len)) - if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil { + if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil { return err } compiledFilter, err := bpf.Compile(filter) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 5f54f2456..47dadb800 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -66,7 +67,7 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } ops := make([]linux.Sembuf, nsops) - if _, err := t.CopyIn(sembufAddr, ops); err != nil { + if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil { return 0, nil, err } @@ -116,8 +117,8 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case linux.IPC_SET: arg := args[3].Pointer() - s := linux.SemidDS{} - if _, err := t.CopyIn(arg, &s); err != nil { + var s linux.SemidDS + if _, err := s.CopyIn(t, arg); err != nil { return 0, nil, err } @@ -188,7 +189,7 @@ func setValAll(t *kernel.Task, id int32, array usermem.Addr) error { return syserror.EINVAL } vals := make([]uint16, set.Size()) - if _, err := t.CopyIn(array, vals); err != nil { + if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil { return err } creds := auth.CredentialsFromContext(t) @@ -217,7 +218,7 @@ func getValAll(t *kernel.Task, id int32, array usermem.Addr) error { if err != nil { return err } - _, err = t.CopyOut(array, vals) + _, err = primitive.CopyUint16SliceOut(t, array, vals) return err } diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index f0ae8fa8e..584064143 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -112,18 +112,18 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal stat, err := segment.IPCStat(t) if err == nil { - _, err = t.CopyOut(buf, stat) + _, err = stat.CopyOut(t, buf) } return 0, nil, err case linux.IPC_INFO: params := r.IPCInfo() - _, err := t.CopyOut(buf, params) + _, err := params.CopyOut(t, buf) return 0, nil, err case linux.SHM_INFO: info := r.ShmInfo() - _, err := t.CopyOut(buf, info) + _, err := info.CopyOut(t, buf) return 0, nil, err } @@ -137,11 +137,10 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal switch cmd { case linux.IPC_SET: var ds linux.ShmidDS - _, err = t.CopyIn(buf, &ds) - if err != nil { + if _, err = ds.CopyIn(t, buf); err != nil { return 0, nil, err } - err = segment.Set(t, &ds) + err := segment.Set(t, &ds) return 0, nil, err case linux.IPC_RMID: diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 38f573c14..9feaca0da 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -29,8 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -67,10 +67,10 @@ const flagsOffset = 48 const sizeOfInt32 = 4 // messageHeader64Len is the length of a MessageHeader64 struct. -var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) +var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes()) // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. -var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) +var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes()) // baseRecvFlags are the flags that are accepted across recvmsg(2), // recvmmsg(2), and recvfrom(2). @@ -78,6 +78,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | // MessageHeader64 is the 64-bit representation of the msghdr struct used in // the recvmsg and sendmsg syscalls. +// +// +marshal type MessageHeader64 struct { // Name is the optional pointer to a network address buffer. Name uint64 @@ -106,30 +108,14 @@ type MessageHeader64 struct { // multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in // the recvmmsg and sendmmsg syscalls. +// +// +marshal type multipleMessageHeader64 struct { msgHdr MessageHeader64 msgLen uint32 _ int32 } -// CopyInMessageHeader64 copies a message header from user to kernel memory. -func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { - b := t.CopyScratchBuffer(52) - if _, err := t.CopyInBytes(addr, b); err != nil { - return err - } - - msg.Name = usermem.ByteOrder.Uint64(b[0:]) - msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) - msg.Iov = usermem.ByteOrder.Uint64(b[16:]) - msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) - msg.Control = usermem.ByteOrder.Uint64(b[32:]) - msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) - msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) - - return nil -} - // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { @@ -148,10 +134,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { // Get the buffer length. var bufLen uint32 - if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { return err } @@ -160,7 +146,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Write the length unconditionally. - if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil { return err } @@ -173,7 +159,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Copy as much of the address as will fit in the buffer. - encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + encodedAddr := t.CopyScratchBuffer(addr.SizeBytes()) + addr.MarshalUnsafe(encodedAddr) if bufLen > uint32(len(encodedAddr)) { bufLen = uint32(len(encodedAddr)) } @@ -247,9 +234,9 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Copy the file descriptors out. - if _, err := t.CopyOut(socks, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, socks, fds); err != nil { for _, fd := range fds { - if file, _ := t.FDTable().Remove(fd); file != nil { + if file, _ := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } @@ -456,8 +443,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Read the length. Reject negative values. - optLen := int32(0) - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + var optLen int32 + if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil { return 0, nil, err } if optLen < 0 { @@ -471,7 +458,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } @@ -733,7 +720,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -748,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -780,7 +767,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if int(msg.Flags) != mflags { // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -817,17 +804,17 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } // Copy the control data to the caller. - if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -996,7 +983,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -1011,7 +998,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -1022,7 +1009,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index c69941feb..46616c961 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -141,7 +142,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Copy in the offset. var offset int64 - if _, err := t.CopyIn(offsetAddr, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, offsetAddr, &offset); err != nil { return 0, nil, err } @@ -149,11 +150,11 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, SrcOffset: true, - SrcStart: offset, + SrcStart: int64(offset), }, outFile.Flags().NonBlocking) // Copy out the new offset. - if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { + if _, err := primitive.CopyInt64Out(t, offsetAddr, offset+n); err != nil { return 0, nil, err } } else { @@ -228,7 +229,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } var offset int64 - if _, err := t.CopyIn(outOffset, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, outOffset, &offset); err != nil { return 0, nil, err } @@ -246,7 +247,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } var offset int64 - if _, err := t.CopyIn(inOffset, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, inOffset, &offset); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index a5826f2dd..cda29a8b5 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -221,7 +221,7 @@ func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr DevMajor: uint32(devMajor), DevMinor: devMinor, } - _, err := t.CopyOut(statxAddr, &s) + _, err := s.CopyOut(t, statxAddr) return err } @@ -283,7 +283,7 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error { FragmentSize: d.Inode.StableAttr.BlockSize, // Leave other fields 0 like simple_statfs does. } - _, err = t.CopyOut(addr, &statfs) + _, err = statfs.CopyOut(t, addr) return err } diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go index 297de052a..674d341b6 100644 --- a/pkg/sentry/syscalls/linux/sys_sysinfo.go +++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go @@ -43,6 +43,6 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca FreeRAM: memFree, Unit: 1, } - _, err := t.CopyOut(addr, si) + _, err := si.CopyOut(t, addr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 101096038..39ca9ea97 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -19,6 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -311,13 +312,13 @@ func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusage return 0, err } if statusAddr != 0 { - if _, err := t.CopyOut(statusAddr, wr.Status); err != nil { + if _, err := primitive.CopyUint32Out(t, statusAddr, wr.Status); err != nil { return 0, err } } if rusageAddr != 0 { ru := getrusage(wr.Task, linux.RUSAGE_BOTH) - if _, err := t.CopyOut(rusageAddr, &ru); err != nil { + if _, err := ru.CopyOut(t, rusageAddr); err != nil { return 0, err } } @@ -395,14 +396,14 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // as well. if infop != 0 { var si arch.SignalInfo - _, err = t.CopyOut(infop, &si) + _, err = si.CopyOut(t, infop) } } return 0, nil, err } if rusageAddr != 0 { ru := getrusage(wr.Task, linux.RUSAGE_BOTH) - if _, err := t.CopyOut(rusageAddr, &ru); err != nil { + if _, err := ru.CopyOut(t, rusageAddr); err != nil { return 0, nil, err } } @@ -441,7 +442,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal default: t.Warningf("waitid got incomprehensible wait status %d", s) } - _, err = t.CopyOut(infop, &si) + _, err = si.CopyOut(t, infop) return 0, nil, err } @@ -558,9 +559,7 @@ func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // third argument to this system call is nowadays unused. if cpu != 0 { - buf := t.CopyScratchBuffer(4) - usermem.ByteOrder.PutUint32(buf, uint32(t.CPU())) - if _, err := t.CopyOutBytes(cpu, buf); err != nil { + if _, err := primitive.CopyInt32Out(t, cpu, t.CPU()); err != nil { return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index a2a24a027..c5054d2f1 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -168,7 +169,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(r), nil, nil } - if _, err := t.CopyOut(addr, r); err != nil { + if _, err := r.CopyOut(t, addr); err != nil { return 0, nil, err } return uintptr(r), nil, nil @@ -334,8 +335,8 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // Ask the time package for the timezone. _, offset := time.Now().Zone() // This int32 array mimics linux's struct timezone. - timezone := [2]int32{-int32(offset) / 60, 0} - _, err := t.CopyOut(tz, timezone) + timezone := []int32{-int32(offset) / 60, 0} + _, err := primitive.CopyInt32SliceOut(t, tz, timezone) return 0, nil, err } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go index a4c400f87..45eef4feb 100644 --- a/pkg/sentry/syscalls/linux/sys_timer.go +++ b/pkg/sentry/syscalls/linux/sys_timer.go @@ -21,81 +21,63 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const nsecPerSec = int64(time.Second) -// copyItimerValIn copies an ItimerVal from the untrusted app range to the -// kernel. The ItimerVal may be either 32 or 64 bits. -// A NULL address is allowed because because Linux allows -// setitimer(which, NULL, &old_value) which disables the timer. -// There is a KERN_WARN message saying this misfeature will be removed. -// However, that hasn't happened as of 3.19, so we continue to support it. -func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) { - if addr == usermem.Addr(0) { - return linux.ItimerVal{}, nil - } - - switch t.Arch().Width() { - case 8: - // Native size, just copy directly. - var itv linux.ItimerVal - if _, err := t.CopyIn(addr, &itv); err != nil { - return linux.ItimerVal{}, err - } - - return itv, nil - default: - return linux.ItimerVal{}, syserror.ENOSYS - } -} - -// copyItimerValOut copies an ItimerVal to the untrusted app range. -// The ItimerVal may be either 32 or 64 bits. -// A NULL address is allowed, in which case no copy takes place -func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error { - if addr == usermem.Addr(0) { - return nil - } - - switch t.Arch().Width() { - case 8: - // Native size, just copy directly. - _, err := t.CopyOut(addr, itv) - return err - default: - return syserror.ENOSYS - } -} - // Getitimer implements linux syscall getitimer(2). func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + if t.Arch().Width() != 8 { + // Definition of linux.ItimerVal assumes 64-bit architecture. + return 0, nil, syserror.ENOSYS + } + timerID := args[0].Int() - val := args[1].Pointer() + addr := args[1].Pointer() olditv, err := t.Getitimer(timerID) if err != nil { return 0, nil, err } - return 0, nil, copyItimerValOut(t, val, &olditv) + // A NULL address is allowed, in which case no copy out takes place. + if addr == 0 { + return 0, nil, nil + } + _, err = olditv.CopyOut(t, addr) + return 0, nil, err } // Setitimer implements linux syscall setitimer(2). func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - timerID := args[0].Int() - newVal := args[1].Pointer() - oldVal := args[2].Pointer() + if t.Arch().Width() != 8 { + // Definition of linux.ItimerVal assumes 64-bit architecture. + return 0, nil, syserror.ENOSYS + } - newitv, err := copyItimerValIn(t, newVal) - if err != nil { - return 0, nil, err + timerID := args[0].Int() + newAddr := args[1].Pointer() + oldAddr := args[2].Pointer() + + var newitv linux.ItimerVal + // A NULL address is allowed because because Linux allows + // setitimer(which, NULL, &old_value) which disables the timer. There is a + // KERN_WARN message saying this misfeature will be removed. However, that + // hasn't happened as of 3.19, so we continue to support it. + if newAddr != 0 { + if _, err := newitv.CopyIn(t, newAddr); err != nil { + return 0, nil, err + } } olditv, err := t.Setitimer(timerID, newitv) if err != nil { return 0, nil, err } - return 0, nil, copyItimerValOut(t, oldVal, &olditv) + // A NULL address is allowed, in which case no copy out takes place. + if oldAddr == 0 { + return 0, nil, nil + } + _, err = olditv.CopyOut(t, oldAddr) + return 0, nil, err } // Alarm implements linux syscall alarm(2). @@ -131,7 +113,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S var sev *linux.Sigevent if sevp != 0 { sev = &linux.Sigevent{} - if _, err = t.CopyIn(sevp, sev); err != nil { + if _, err = sev.CopyIn(t, sevp); err != nil { return 0, nil, err } } @@ -141,7 +123,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } - if _, err := t.CopyOut(timerIDp, &id); err != nil { + if _, err := id.CopyOut(t, timerIDp); err != nil { t.IntervalTimerDelete(id) return 0, nil, err } @@ -157,7 +139,7 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. oldValAddr := args[3].Pointer() var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0) @@ -165,9 +147,8 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, err } if oldValAddr != 0 { - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { - return 0, nil, err - } + _, err = oldVal.CopyOut(t, oldValAddr) + return 0, nil, err } return 0, nil, nil } @@ -181,7 +162,7 @@ func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if err != nil { return 0, nil, err } - _, err = t.CopyOut(curValAddr, &curVal) + _, err = curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go index 34b03e4ee..cadd9d348 100644 --- a/pkg/sentry/syscalls/linux/sys_timerfd.go +++ b/pkg/sentry/syscalls/linux/sys_timerfd.go @@ -81,7 +81,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock()) @@ -91,7 +91,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, oldS := tf.SetTime(newS) if oldValAddr != 0 { oldVal := ktime.ItimerspecFromSetting(tm, oldS) - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + if _, err := oldVal.CopyOut(t, oldValAddr); err != nil { return 0, nil, err } } @@ -116,6 +116,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, s := tf.GetTime() curVal := ktime.ItimerspecFromSetting(tm, s) - _, err := t.CopyOut(curValAddr, &curVal) + _, err := curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go index b3eb96a1c..6ddd30d5c 100644 --- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go @@ -18,6 +18,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -30,17 +31,19 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys case linux.ARCH_GET_FS: addr := args[1].Pointer() fsbase := t.Arch().TLS() - _, err := t.CopyOut(addr, uint64(fsbase)) - if err != nil { - return 0, nil, err + switch t.Arch().Width() { + case 8: + if _, err := primitive.CopyUint64Out(t, addr, uint64(fsbase)); err != nil { + return 0, nil, err + } + default: + return 0, nil, syserror.ENOSYS } - case linux.ARCH_SET_FS: fsbase := args[1].Uint64() if !t.Arch().SetTLS(uintptr(fsbase)) { return 0, nil, syserror.EPERM } - case linux.ARCH_GET_GS, linux.ARCH_SET_GS: t.Kernel().EmitUnimplementedEvent(t) fallthrough diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index e9d702e8e..66c5974f5 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -46,7 +46,7 @@ func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Copy out the result. va := args[0].Pointer() - _, err := t.CopyOut(va, u) + _, err := u.CopyOut(t, va) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 64696b438..9ee766552 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -44,6 +44,9 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/gohacks", + "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", @@ -72,7 +75,5 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 42559bf69..6d0a38330 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -17,6 +17,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -38,21 +39,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } for i := int32(0); i < nrEvents; i++ { - // Copy in the address. - cbAddrNative := t.Arch().Native(0) - if _, err := t.CopyIn(addr, cbAddrNative); err != nil { - if i > 0 { - // Some successful. - return uintptr(i), nil, nil + // Copy in the callback address. + var cbAddr usermem.Addr + switch t.Arch().Width() { + case 8: + var cbAddrP primitive.Uint64 + if _, err := cbAddrP.CopyIn(t, addr); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err } - // Nothing done. - return 0, nil, err + cbAddr = usermem.Addr(cbAddrP) + default: + return 0, nil, syserror.ENOSYS } // Copy in this callback. var cb linux.IOCallback - cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) - if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if _, err := cb.CopyIn(t, cbAddr); err != nil { if i > 0 { // Some have been successful. return uintptr(i), nil, nil diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index c62f03509..d0cbb77eb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -24,7 +24,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -141,50 +140,26 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EINVAL } - // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent, - // maxEvents), so that the buffer can be allocated on the stack. + // Allocate space for a few events on the stack for the common case in + // which we don't have too many events. var ( - events [16]linux.EpollEvent - total int + eventsArr [16]linux.EpollEvent ch chan struct{} haveDeadline bool deadline ktime.Time ) for { - batchEvents := len(events) - if batchEvents > maxEvents { - batchEvents = maxEvents - } - n := ep.ReadEvents(events[:batchEvents]) - maxEvents -= n - if n != 0 { - // Copy what we read out. - copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n]) + events := ep.ReadEvents(eventsArr[:0], maxEvents) + if len(events) != 0 { + copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events) copiedEvents := copiedBytes / sizeofEpollEvent // rounded down - eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent) - total += copiedEvents - if err != nil { - if total != 0 { - return uintptr(total), nil, nil - } - return 0, nil, err - } - // If we've filled the application's event buffer, we're done. - if maxEvents == 0 { - return uintptr(total), nil, nil - } - // Loop if we read a full batch, under the expectation that there - // may be more events to read. - if n == batchEvents { - continue + if copiedEvents != 0 { + return uintptr(copiedEvents), nil, nil } + return 0, nil, err } - // We get here if n != batchEvents. If we read any number of events - // (just now, or in a previous iteration of this loop), or if timeout - // is 0 (such that epoll_wait should be non-blocking), return the - // events we've read so far to the application. - if total != 0 || timeout == 0 { - return uintptr(total), nil, nil + if timeout == 0 { + return 0, nil, nil } // In the first iteration of this loop, register with the epoll // instance for readability events, but then immediately continue the @@ -207,8 +182,6 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if err == syserror.ETIMEDOUT { err = nil } - // total must be 0 since otherwise we would have returned - // above. return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 4856554fe..d8b8d9783 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -34,7 +34,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that Remove provides a reference on the file that we may use to // flush. It is still active until we drop the final reference below // (and other reference-holding operations complete). - _, file := t.FDTable().Remove(fd) + _, file := t.FDTable().Remove(t, fd) if file == nil { return 0, nil, syserror.EBADF } @@ -137,7 +137,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + err := t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) return 0, nil, err @@ -181,11 +181,11 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if !hasOwner { return 0, nil, nil } - _, err := t.CopyOut(args[2].Pointer(), &owner) + _, err := owner.CopyOut(t, args[2].Pointer()) return 0, nil, err case linux.F_SETOWN_EX: var owner linux.FOwnerEx - _, err := t.CopyIn(args[2].Pointer(), &owner) + _, err := owner.CopyIn(t, args[2].Pointer()) if err != nil { return 0, nil, err } @@ -286,7 +286,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock - if _, err := t.CopyIn(flockAddr, &flock); err != nil { + if _, err := flock.CopyIn(t, flockAddr); err != nil { return err } diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 38778a388..2806c3f6f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -34,20 +35,20 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Handle ioctls that apply to all FDs. switch args[1].Int() { case linux.FIONCLEX: - t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: false, }) return 0, nil, nil case linux.FIOCLEX: - t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: true, }) return 0, nil, nil case linux.FIONBIO: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.StatusFlags() @@ -60,7 +61,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOASYNC: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.StatusFlags() @@ -82,12 +83,12 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall who = owner.PID } } - _, err := t.CopyOut(args[2].Pointer(), &who) + _, err := primitive.CopyInt32Out(t, args[2].Pointer(), who) return 0, nil, err case linux.FIOSETOWN, linux.SIOCSPGRP: var who int32 - if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &who); err != nil { return 0, nil, err } ownerType := int32(linux.F_OWNER_PID) diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index dc05c2994..9d9dbf775 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -17,6 +17,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" @@ -85,6 +86,17 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if err := file.ConfigureMMap(t, &opts); err != nil { return 0, nil, err } + } else if shared { + // Back shared anonymous mappings with an anonymous tmpfs file. + opts.Offset = 0 + file, err := tmpfs.NewZeroFile(t, t.Credentials(), t.Kernel().ShmMount(), opts.Length) + if err != nil { + return 0, nil, err + } + defer file.DecRef(t) + if err := file.ConfigureMMap(t, &opts); err != nil { + return 0, nil, err + } } rv, err := t.MemoryManager().MMap(t, opts) diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index 9b4848d9e..ee38fdca0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -51,9 +52,9 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if err != nil { return err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if _, file := t.FDTable().Remove(fd); file != nil { + if _, file := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index 79ad64039..c22e4ce54 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -165,7 +165,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD pfd := make([]linux.PollFD, nfds) if nfds > 0 { - if _, err := t.CopyIn(addr, &pfd); err != nil { + if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil { return nil, err } } @@ -192,7 +192,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. if nfds > 0 && err == nil { - if _, err := t.CopyOut(addr, pfd); err != nil { + if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil { return remainingTimeout, 0, err } } @@ -205,7 +205,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy set := make([]byte, nBytes) if addr != 0 { - if _, err := t.CopyIn(addr, &set); err != nil { + if _, err := t.CopyInBytes(addr, set); err != nil { return nil, err } // If we only use part of the last byte, mask out the extraneous bits. @@ -332,19 +332,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Copy updated vectors back. if readFDs != 0 { - if _, err := t.CopyOut(readFDs, r); err != nil { + if _, err := t.CopyOutBytes(readFDs, r); err != nil { return 0, err } } if writeFDs != 0 { - if _, err := t.CopyOut(writeFDs, w); err != nil { + if _, err := t.CopyOutBytes(writeFDs, w); err != nil { return 0, err } } if exceptFDs != 0 { - if _, err := t.CopyOut(exceptFDs, e); err != nil { + if _, err := t.CopyOutBytes(exceptFDs, e); err != nil { return 0, err } } @@ -497,6 +497,12 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return n, nil, err } +// +marshal +type sigSetWithSize struct { + sigsetAddr uint64 + sizeofSigset uint64 +} + // Pselect implements linux syscall pselect(2). func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { nfds := int(args[0].Int()) // select(2) uses an int. @@ -538,12 +544,6 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return n, nil, err } -// +marshal -type sigSetWithSize struct { - sigsetAddr uint64 - sizeofSigset uint64 -} - // copyTimespecInToDuration copies a Timespec from the untrusted app range, // validates it and converts it to a Duration. // diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 5e6eb13ba..1ee37e5a8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -346,7 +346,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opt return nil } var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return err } if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 { @@ -410,7 +410,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op return nil } var times [2]linux.Timespec - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil { return err } if times[0].Nsec != linux.UTIME_OMIT { diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index a5032657a..bfae6b7e9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -30,8 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -66,10 +66,10 @@ const flagsOffset = 48 const sizeOfInt32 = 4 // messageHeader64Len is the length of a MessageHeader64 struct. -var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) +var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes()) // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. -var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) +var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes()) // baseRecvFlags are the flags that are accepted across recvmsg(2), // recvmmsg(2), and recvfrom(2). @@ -77,6 +77,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | // MessageHeader64 is the 64-bit representation of the msghdr struct used in // the recvmsg and sendmsg syscalls. +// +// +marshal type MessageHeader64 struct { // Name is the optional pointer to a network address buffer. Name uint64 @@ -105,30 +107,14 @@ type MessageHeader64 struct { // multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in // the recvmmsg and sendmmsg syscalls. +// +// +marshal type multipleMessageHeader64 struct { msgHdr MessageHeader64 msgLen uint32 _ int32 } -// CopyInMessageHeader64 copies a message header from user to kernel memory. -func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { - b := t.CopyScratchBuffer(52) - if _, err := t.CopyInBytes(addr, b); err != nil { - return err - } - - msg.Name = usermem.ByteOrder.Uint64(b[0:]) - msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) - msg.Iov = usermem.ByteOrder.Uint64(b[16:]) - msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) - msg.Control = usermem.ByteOrder.Uint64(b[32:]) - msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) - msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) - - return nil -} - // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { @@ -147,10 +133,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { // Get the buffer length. var bufLen uint32 - if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { return err } @@ -159,7 +145,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Write the length unconditionally. - if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil { return err } @@ -172,7 +158,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Copy as much of the address as will fit in the buffer. - encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + encodedAddr := t.CopyScratchBuffer(addr.SizeBytes()) + addr.MarshalUnsafe(encodedAddr) if bufLen > uint32(len(encodedAddr)) { bufLen = uint32(len(encodedAddr)) } @@ -250,9 +237,9 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if _, file := t.FDTable().Remove(fd); file != nil { + if _, file := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } @@ -459,8 +446,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Read the length. Reject negative values. - optLen := int32(0) - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + var optLen int32 + if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil { return 0, nil, err } if optLen < 0 { @@ -474,7 +461,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } @@ -736,7 +723,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -751,7 +738,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -783,7 +770,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla if int(msg.Flags) != mflags { // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -820,17 +807,17 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } // Copy the control data to the caller. - if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -999,7 +986,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -1014,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -1025,7 +1012,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 68ce94778..bf5c1171f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -18,9 +18,12 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -88,7 +91,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inFile.Options().DenyPRead { return 0, nil, syserror.EINVAL } - if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil { + if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil { return 0, nil, err } if inOffset < 0 { @@ -103,7 +106,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if outFile.Options().DenyPWrite { return 0, nil, syserror.EINVAL } - if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil { + if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil { return 0, nil, err } if outOffset < 0 { @@ -144,11 +147,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal panic("at least one end of splice must be a pipe") } - if n == 0 && err == io.EOF { - // We reached the end of the file. Eat the error and exit the loop. - err = nil - break - } if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } @@ -159,25 +157,26 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Copy updated offsets out. if inOffsetPtr != 0 { - if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil { + if _, err := primitive.CopyInt64Out(t, inOffsetPtr, inOffset); err != nil { return 0, nil, err } } if outOffsetPtr != 0 { - if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil { + if _, err := primitive.CopyInt64Out(t, outOffsetPtr, outOffset); err != nil { return 0, nil, err } } - if n == 0 { - return 0, nil, err + if n != 0 { + // On Linux, inotify behavior is not very consistent with splice(2). We try + // our best to emulate Linux for very basic calls to splice, where for some + // reason, events are generated for output files, but not input files. + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) } - // On Linux, inotify behavior is not very consistent with splice(2). We try - // our best to emulate Linux for very basic calls to splice, where for some - // reason, events are generated for output files, but not input files. - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - return uintptr(n), nil, nil + // We can only pass a single file to handleIOError, so pick inFile arbitrarily. + // This is used only for debugging purposes. + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile) } // Tee implements Linux syscall tee(2). @@ -249,11 +248,20 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo break } } - if n == 0 { - return 0, nil, err + + if n != 0 { + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) + + // If a partial write is completed, the error is dropped. Log it here. + if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { + log.Debugf("tee completed a partial write with error: %v", err) + err = nil + } } - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - return uintptr(n), nil, nil + + // We can only pass a single file to handleIOError, so pick inFile arbitrarily. + // This is used only for debugging purposes. + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "tee", inFile) } // Sendfile implements linux system call sendfile(2). @@ -302,9 +310,12 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if inFile.Options().DenyPRead { return 0, nil, syserror.ESPIPE } - if _, err := t.CopyIn(offsetAddr, &offset); err != nil { + var offsetP primitive.Int64 + if _, err := offsetP.CopyIn(t, offsetAddr); err != nil { return 0, nil, err } + offset = int64(offsetP) + if offset < 0 { return 0, nil, syserror.EINVAL } @@ -343,11 +354,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc for n < count { var spliceN int64 spliceN, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count) - if spliceN == 0 && err == io.EOF { - // We reached the end of the file. Eat the error and exit the loop. - err = nil - break - } if offset != -1 { offset += spliceN } @@ -370,13 +376,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } else { readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) } - if readN == 0 && err != nil { - if err == io.EOF { - // We reached the end of the file. Eat the error before exiting the loop. - err = nil - } - break - } n += readN // Write all of the bytes that we read. This may need @@ -390,16 +389,21 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc err = dw.waitForOut(t) } if err != nil { - // We didn't complete the write. Only - // report the bytes that were actually - // written, and rewind the offset. + // We didn't complete the write. Only report the bytes that were actually + // written, and rewind offsets as needed. notWritten := int64(len(wbuf)) n -= notWritten - if offset != -1 { - // TODO(gvisor.dev/issue/3779): The inFile offset will be incorrect if we - // roll back, because it has already been advanced by the full amount. - // Merely seeking on inFile does not work, because there may be concurrent - // file operations. + if offset == -1 { + // We modified the offset of the input file itself during the read + // operation. Rewind it. + if _, seekErr := inFile.Seek(t, -notWritten, linux.SEEK_CUR); seekErr != nil { + // Log the error but don't return it, since the write has already + // completed successfully. + log.Warningf("failed to roll back input file offset: %v", seekErr) + } + } else { + // The sendfile call was provided an offset parameter that should be + // adjusted to reflect the number of bytes sent. Rewind it. offset -= notWritten } break @@ -416,18 +420,26 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if offsetAddr != 0 { // Copy out the new offset. - if _, err := t.CopyOut(offsetAddr, offset); err != nil { + offsetP := primitive.Uint64(offset) + if _, err := offsetP.CopyOut(t, offsetAddr); err != nil { return 0, nil, err } } - if n == 0 { - return 0, nil, err + if n != 0 { + inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) + + if err != nil && err != io.EOF && err != syserror.ErrWouldBlock { + // If a partial write is completed, the error is dropped. Log it here. + log.Debugf("sendfile completed a partial write with error: %v", err) + err = nil + } } - inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) - outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) - return uintptr(n), nil, nil + // We can only pass a single file to handleIOError, so pick inFile arbitrarily. + // This is used only for debugging purposes. + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) } // dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go index 7a26890ef..250870c03 100644 --- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go @@ -87,7 +87,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock()) @@ -97,7 +97,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, oldS := tfd.SetTime(newS) if oldValAddr != 0 { oldVal := ktime.ItimerspecFromSetting(tm, oldS) - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + if _, err := oldVal.CopyOut(t, oldValAddr); err != nil { return 0, nil, err } } @@ -122,6 +122,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, s := tfd.GetTime() curVal := ktime.ItimerspecFromSetting(tm, s) - _, err := t.CopyOut(curValAddr, &curVal) + _, err := curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index c576d9475..0df3bd449 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -93,16 +93,16 @@ func Override() { s.Table[165] = syscalls.Supported("mount", Mount) s.Table[166] = syscalls.Supported("umount2", Umount2) s.Table[187] = syscalls.Supported("readahead", Readahead) - s.Table[188] = syscalls.Supported("setxattr", Setxattr) + s.Table[188] = syscalls.Supported("setxattr", SetXattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr) - s.Table[191] = syscalls.Supported("getxattr", Getxattr) + s.Table[191] = syscalls.Supported("getxattr", GetXattr) s.Table[192] = syscalls.Supported("lgetxattr", Lgetxattr) s.Table[193] = syscalls.Supported("fgetxattr", Fgetxattr) - s.Table[194] = syscalls.Supported("listxattr", Listxattr) + s.Table[194] = syscalls.Supported("listxattr", ListXattr) s.Table[195] = syscalls.Supported("llistxattr", Llistxattr) s.Table[196] = syscalls.Supported("flistxattr", Flistxattr) - s.Table[197] = syscalls.Supported("removexattr", Removexattr) + s.Table[197] = syscalls.Supported("removexattr", RemoveXattr) s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr) s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr) s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) @@ -163,16 +163,16 @@ func Override() { // Override ARM64. s = linux.ARM64 - s.Table[5] = syscalls.Supported("setxattr", Setxattr) + s.Table[5] = syscalls.Supported("setxattr", SetXattr) s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) - s.Table[8] = syscalls.Supported("getxattr", Getxattr) + s.Table[8] = syscalls.Supported("getxattr", GetXattr) s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr) s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr) - s.Table[11] = syscalls.Supported("listxattr", Listxattr) + s.Table[11] = syscalls.Supported("listxattr", ListXattr) s.Table[12] = syscalls.Supported("llistxattr", Llistxattr) s.Table[13] = syscalls.Supported("flistxattr", Flistxattr) - s.Table[14] = syscalls.Supported("removexattr", Removexattr) + s.Table[14] = syscalls.Supported("removexattr", RemoveXattr) s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr) s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr) s.Table[17] = syscalls.Supported("getcwd", Getcwd) diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index ef99246ed..e05723ef9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -26,8 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Listxattr implements Linux syscall listxattr(2). -func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// ListXattr implements Linux syscall listxattr(2). +func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return listxattr(t, args, followFinalSymlink) } @@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml } defer tpop.Release(t) - names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) + names, err := t.Kernel().VFS().ListXattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) if err != nil { return 0, nil, err } @@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } defer file.DecRef(t) - names, err := file.Listxattr(t, uint64(size)) + names, err := file.ListXattr(t, uint64(size)) if err != nil { return 0, nil, err } @@ -85,8 +85,8 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return uintptr(n), nil, nil } -// Getxattr implements Linux syscall getxattr(2). -func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// GetXattr implements Linux syscall getxattr(2). +func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return getxattr(t, args, followFinalSymlink) } @@ -116,7 +116,7 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return 0, nil, err } - value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{ + value, err := t.Kernel().VFS().GetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetXattrOptions{ Name: name, Size: uint64(size), }) @@ -148,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)}) + value, err := file.GetXattr(t, &vfs.GetXattrOptions{Name: name, Size: uint64(size)}) if err != nil { return 0, nil, err } @@ -159,8 +159,8 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return uintptr(n), nil, nil } -// Setxattr implements Linux syscall setxattr(2). -func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// SetXattr implements Linux syscall setxattr(2). +func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, setxattr(t, args, followFinalSymlink) } @@ -199,7 +199,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return err } - return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{ + return t.Kernel().VFS().SetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetXattrOptions{ Name: name, Value: value, Flags: uint32(flags), @@ -233,15 +233,15 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{ + return 0, nil, file.SetXattr(t, &vfs.SetXattrOptions{ Name: name, Value: value, Flags: uint32(flags), }) } -// Removexattr implements Linux syscall removexattr(2). -func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// RemoveXattr implements Linux syscall removexattr(2). +func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, removexattr(t, args, followFinalSymlink) } @@ -269,7 +269,7 @@ func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSy return err } - return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name) + return t.Kernel().VFS().RemoveXattrAt(t, t.Credentials(), &tpop.pop, name) } // Fremovexattr implements Linux syscall fremovexattr(2). @@ -288,7 +288,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, err } - return 0, nil, file.Removexattr(t, name) + return 0, nil, file.RemoveXattr(t, name) } func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 5a0e3e6b5..bdfd3ca8f 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -52,6 +52,8 @@ const ( ) // anonFilesystemType implements FilesystemType. +// +// +stateify savable type anonFilesystemType struct{} // GetFilesystem implements FilesystemType.GetFilesystem. @@ -69,12 +71,15 @@ func (anonFilesystemType) Name() string { // // Since all Dentries in anonFilesystem are non-directories, all FilesystemImpl // methods that would require an anonDentry to be a directory return ENOTDIR. +// +// +stateify savable type anonFilesystem struct { vfsfs Filesystem devMinor uint32 } +// +stateify savable type anonDentry struct { vfsd Dentry @@ -245,32 +250,32 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements FilesystemImpl.ListxattrAt. -func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements FilesystemImpl.ListXattrAt. +func (fs *anonFilesystem) ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { if !rp.Done() { return nil, syserror.ENOTDIR } return nil, nil } -// GetxattrAt implements FilesystemImpl.GetxattrAt. -func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) { +// GetXattrAt implements FilesystemImpl.GetXattrAt. +func (fs *anonFilesystem) GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) { if !rp.Done() { return "", syserror.ENOTDIR } return "", syserror.ENOTSUP } -// SetxattrAt implements FilesystemImpl.SetxattrAt. -func (fs *anonFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { +// SetXattrAt implements FilesystemImpl.SetXattrAt. +func (fs *anonFilesystem) SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error { if !rp.Done() { return syserror.ENOTDIR } return syserror.EPERM } -// RemovexattrAt implements FilesystemImpl.RemovexattrAt. -func (fs *anonFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { +// RemoveXattrAt implements FilesystemImpl.RemoveXattrAt. +func (fs *anonFilesystem) RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error { if !rp.Done() { return syserror.ENOTDIR } diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index a69a5b2f1..320ab7ce1 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -89,6 +89,8 @@ func (d *Dentry) Impl() DentryImpl { // DentryImpl contains implementation details for a Dentry. Implementations of // DentryImpl should contain their associated Dentry by value as their first // field. +// +// +stateify savable type DentryImpl interface { // IncRef increments the Dentry's reference count. A Dentry with a non-zero // reference count must remain coherent with the state of the filesystem. diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go index 1e9dffc8f..dde2ad79b 100644 --- a/pkg/sentry/vfs/device.go +++ b/pkg/sentry/vfs/device.go @@ -22,6 +22,8 @@ import ( ) // DeviceKind indicates whether a device is a block or character device. +// +// +stateify savable type DeviceKind uint32 const ( @@ -44,6 +46,7 @@ func (kind DeviceKind) String() string { } } +// +stateify savable type devTuple struct { kind DeviceKind major uint32 diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 1b5af9f73..8f36c3e3b 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -27,6 +27,8 @@ import ( var epollCycleMu sync.Mutex // EpollInstance represents an epoll instance, as described by epoll(7). +// +// +stateify savable type EpollInstance struct { vfsfd FileDescription FileDescriptionDefaultImpl @@ -38,11 +40,11 @@ type EpollInstance struct { // interest is the set of file descriptors that are registered with the // EpollInstance for monitoring. interest is protected by interestMu. - interestMu sync.Mutex + interestMu sync.Mutex `state:"nosave"` interest map[epollInterestKey]*epollInterest // mu protects fields in registered epollInterests. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // ready is the set of file descriptors that may be "ready" for I/O. Note // that this must be an ordered list, not a map: "If more than maxevents @@ -55,6 +57,7 @@ type EpollInstance struct { ready epollInterestList } +// +stateify savable type epollInterestKey struct { // file is the registered FileDescription. No reference is held on file; // instead, when the last reference is dropped, FileDescription.DecRef() @@ -67,6 +70,8 @@ type epollInterestKey struct { } // epollInterest represents an EpollInstance's interest in a file descriptor. +// +// +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. epoll *EpollInstance @@ -331,11 +336,9 @@ func (ep *EpollInstance) removeLocked(epi *epollInterest) { ep.mu.Unlock() } -// ReadEvents reads up to len(events) ready events into events and returns the -// number of events read. -// -// Preconditions: len(events) != 0. -func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int { +// ReadEvents appends up to maxReady events to events and returns the updated +// slice of events. +func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent, maxEvents int) []linux.EpollEvent { i := 0 // Hot path: avoid defer. ep.mu.Lock() @@ -368,16 +371,16 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int { requeue.PushBack(epi) } // Report ievents. - events[i] = linux.EpollEvent{ + events = append(events, linux.EpollEvent{ Events: ievents.ToLinux(), Data: epi.userData, - } + }) i++ - if i == len(events) { + if i == maxEvents { break } } ep.ready.PushBackList(&requeue) ep.mu.Unlock() - return i + return events } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 22a54fa48..1eba0270f 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -37,11 +37,13 @@ import ( // FileDescription methods require that a reference is held. // // FileDescription is analogous to Linux's struct file. +// +// +stateify savable type FileDescription struct { FileDescriptionRefs // flagsMu protects statusFlags and asyncHandler below. - flagsMu sync.Mutex + flagsMu sync.Mutex `state:"nosave"` // statusFlags contains status flags, "initialized by open(2) and possibly // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic @@ -56,7 +58,7 @@ type FileDescription struct { // epolls is the set of epollInterests registered for this FileDescription. // epolls is protected by epollMu. - epollMu sync.Mutex + epollMu sync.Mutex `state:"nosave"` epolls map[*epollInterest]struct{} // vd is the filesystem location at which this FileDescription was opened. @@ -88,6 +90,8 @@ type FileDescription struct { } // FileDescriptionOptions contains options to FileDescription.Init(). +// +// +stateify savable type FileDescriptionOptions struct { // If AllowDirectIO is true, allow O_DIRECT to be set on the file. AllowDirectIO bool @@ -101,7 +105,7 @@ type FileDescriptionOptions struct { // If UseDentryMetadata is true, calls to FileDescription methods that // interact with file and filesystem metadata (Stat, SetStat, StatFS, - // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling + // ListXattr, GetXattr, SetXattr, RemoveXattr) are implemented by calling // the corresponding FilesystemImpl methods instead of the corresponding // FileDescriptionImpl methods. // @@ -326,6 +330,9 @@ type FileDescriptionImpl interface { // Allocate grows the file to offset + length bytes. // Only mode == 0 is supported currently. // + // Allocate should return EISDIR on directories, ESPIPE on pipes, and ENODEV on + // other files where it is not supported. + // // Preconditions: The FileDescription was opened for writing. Allocate(ctx context.Context, mode, offset, length uint64) error @@ -420,19 +427,19 @@ type FileDescriptionImpl interface { // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // Listxattr returns all extended attribute names for the file. - Listxattr(ctx context.Context, size uint64) ([]string, error) + // ListXattr returns all extended attribute names for the file. + ListXattr(ctx context.Context, size uint64) ([]string, error) - // Getxattr returns the value associated with the given extended attribute + // GetXattr returns the value associated with the given extended attribute // for the file. - Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) + GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) - // Setxattr changes the value associated with the given extended attribute + // SetXattr changes the value associated with the given extended attribute // for the file. - Setxattr(ctx context.Context, opts SetxattrOptions) error + SetXattr(ctx context.Context, opts SetXattrOptions) error - // Removexattr removes the given extended attribute from the file. - Removexattr(ctx context.Context, name string) error + // RemoveXattr removes the given extended attribute from the file. + RemoveXattr(ctx context.Context, name string) error // LockBSD tries to acquire a BSD-style advisory file lock. LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error @@ -448,6 +455,8 @@ type FileDescriptionImpl interface { } // Dirent holds the information contained in struct linux_dirent64. +// +// +stateify savable type Dirent struct { // Name is the filename. Name string @@ -635,25 +644,25 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. return fd.impl.Ioctl(ctx, uio, args) } -// Listxattr returns all extended attribute names for the file represented by +// ListXattr returns all extended attribute names for the file represented by // fd. // // If the size of the list (including a NUL terminating byte after every entry) // would exceed size, ERANGE may be returned. Note that implementations // are free to ignore size entirely and return without error). In all cases, // if size is 0, the list should be returned without error, regardless of size. -func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { +func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size) + names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) vfsObj.putResolvingPath(ctx, rp) return names, err } - names, err := fd.impl.Listxattr(ctx, size) + names, err := fd.impl.ListXattr(ctx, size) if err == syserror.ENOTSUP { // Linux doesn't actually return ENOTSUP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security @@ -664,57 +673,57 @@ func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string return names, err } -// Getxattr returns the value associated with the given extended attribute for +// GetXattr returns the value associated with the given extended attribute for // the file represented by fd. // // If the size of the return value exceeds opts.Size, ERANGE may be returned // (note that implementations are free to ignore opts.Size entirely and return // without error). In all cases, if opts.Size is 0, the value should be // returned without error, regardless of size. -func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) { +func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) (string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts) + val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(ctx, rp) return val, err } - return fd.impl.Getxattr(ctx, *opts) + return fd.impl.GetXattr(ctx, *opts) } -// Setxattr changes the value associated with the given extended attribute for +// SetXattr changes the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error { +func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(ctx, rp) return err } - return fd.impl.Setxattr(ctx, *opts) + return fd.impl.SetXattr(ctx, *opts) } -// Removexattr removes the given extended attribute from the file represented +// RemoveXattr removes the given extended attribute from the file represented // by fd. -func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { +func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name) + err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) vfsObj.putResolvingPath(ctx, rp) return err } - return fd.impl.Removexattr(ctx, name) + return fd.impl.RemoveXattr(ctx, name) } // SyncFS instructs the filesystem containing fd to execute the semantics of diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 6b8b4ad49..48ca9de44 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -42,6 +42,8 @@ import ( // FileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl to obtain implementations of many FileDescriptionImpl // methods with default behavior analogous to Linux's. +// +// +stateify savable type FileDescriptionDefaultImpl struct{} // OnClose implements FileDescriptionImpl.OnClose analogously to @@ -57,7 +59,11 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err } // Allocate implements FileDescriptionImpl.Allocate analogously to -// fallocate called on regular file, directory or FIFO in Linux. +// fallocate called on an invalid type of file in Linux. +// +// Note that directories can rely on this implementation even though they +// should technically return EISDIR. Allocate should never be called for a +// directory, because it requires a writable fd. func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { return syserror.ENODEV } @@ -134,34 +140,36 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } -// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// ListXattr implements FileDescriptionImpl.ListXattr analogously to // inode_operations::listxattr == NULL in Linux. -func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) { - // This isn't exactly accurate; see FileDescription.Listxattr. +func (FileDescriptionDefaultImpl) ListXattr(ctx context.Context, size uint64) ([]string, error) { + // This isn't exactly accurate; see FileDescription.ListXattr. return nil, syserror.ENOTSUP } -// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// GetXattr implements FileDescriptionImpl.GetXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) { +func (FileDescriptionDefaultImpl) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { return "", syserror.ENOTSUP } -// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// SetXattr implements FileDescriptionImpl.SetXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { +func (FileDescriptionDefaultImpl) SetXattr(ctx context.Context, opts SetXattrOptions) error { return syserror.ENOTSUP } -// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// RemoveXattr implements FileDescriptionImpl.RemoveXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { +func (FileDescriptionDefaultImpl) RemoveXattr(ctx context.Context, name string) error { return syserror.ENOTSUP } // DirectoryFileDescriptionDefaultImpl may be embedded by implementations of // FileDescriptionImpl that always represent directories to obtain // implementations of non-directory I/O methods that return EISDIR. +// +// +stateify savable type DirectoryFileDescriptionDefaultImpl struct{} // Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate. @@ -192,6 +200,8 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme // DentryMetadataFileDescriptionImpl may be embedded by implementations of // FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is // true to obtain implementations of Stat and SetStat that panic. +// +// +stateify savable type DentryMetadataFileDescriptionImpl struct{} // Stat implements FileDescriptionImpl.Stat. @@ -206,12 +216,16 @@ func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetSt // DynamicBytesSource represents a data source for a // DynamicBytesFileDescriptionImpl. +// +// +stateify savable type DynamicBytesSource interface { // Generate writes the file's contents to buf. Generate(ctx context.Context, buf *bytes.Buffer) error } // StaticData implements DynamicBytesSource over a static string. +// +// +stateify savable type StaticData struct { Data string } @@ -238,14 +252,24 @@ type WritableDynamicBytesSource interface { // // DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first // use. +// +// +stateify savable type DynamicBytesFileDescriptionImpl struct { data DynamicBytesSource // immutable - mu sync.Mutex // protects the following fields - buf bytes.Buffer + mu sync.Mutex `state:"nosave"` // protects the following fields + buf bytes.Buffer `state:".([]byte)"` off int64 lastRead int64 // offset at which the last Read, PRead, or Seek ended } +func (fd *DynamicBytesFileDescriptionImpl) saveBuf() []byte { + return fd.buf.Bytes() +} + +func (fd *DynamicBytesFileDescriptionImpl) loadBuf(p []byte) { + fd.buf.Write(p) +} + // SetDataSource must be called exactly once on fd before first use. func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) { fd.data = data @@ -378,6 +402,8 @@ func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.M // LockFD may be used by most implementations of FileDescriptionImpl.Lock* // functions. Caller must call Init(). +// +// +stateify savable type LockFD struct { locks *FileLocks } @@ -405,6 +431,8 @@ func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { // NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface // returning ENOLCK. +// +// +stateify savable type NoLockFD struct{} // LockBSD implements vfs.FileDescriptionImpl.LockBSD. diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 46851f638..c93d94634 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -416,26 +416,26 @@ type FilesystemImpl interface { // ResolvingPath.Resolve*(), then !rp.Done(). UnlinkAt(ctx context.Context, rp *ResolvingPath) error - // ListxattrAt returns all extended attribute names for the file at rp. + // ListXattrAt returns all extended attribute names for the file at rp. // // Errors: // // - If extended attributes are not supported by the filesystem, - // ListxattrAt returns ENOTSUP. + // ListXattrAt returns ENOTSUP. // // - If the size of the list (including a NUL terminating byte after every // entry) would exceed size, ERANGE may be returned. Note that // implementations are free to ignore size entirely and return without // error). In all cases, if size is 0, the list should be returned without // error, regardless of size. - ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) + ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) - // GetxattrAt returns the value associated with the given extended + // GetXattrAt returns the value associated with the given extended // attribute for the file at rp. // // Errors: // - // - If extended attributes are not supported by the filesystem, GetxattrAt + // - If extended attributes are not supported by the filesystem, GetXattrAt // returns ENOTSUP. // // - If an extended attribute named opts.Name does not exist, ENODATA is @@ -445,30 +445,30 @@ type FilesystemImpl interface { // returned (note that implementations are free to ignore opts.Size entirely // and return without error). In all cases, if opts.Size is 0, the value // should be returned without error, regardless of size. - GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) + GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) - // SetxattrAt changes the value associated with the given extended + // SetXattrAt changes the value associated with the given extended // attribute for the file at rp. // // Errors: // - // - If extended attributes are not supported by the filesystem, SetxattrAt + // - If extended attributes are not supported by the filesystem, SetXattrAt // returns ENOTSUP. // // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists, // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist, // ENODATA is returned. - SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error - // RemovexattrAt removes the given extended attribute from the file at rp. + // RemoveXattrAt removes the given extended attribute from the file at rp. // // Errors: // // - If extended attributes are not supported by the filesystem, - // RemovexattrAt returns ENOTSUP. + // RemoveXattrAt returns ENOTSUP. // // - If name does not exist, ENODATA is returned. - RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error // BoundEndpointAt returns the Unix socket endpoint bound at the path rp. // @@ -506,6 +506,8 @@ type FilesystemImpl interface { // PrependPathAtVFSRootError is returned by implementations of // FilesystemImpl.PrependPath() when they encounter the contextual VFS root. +// +// +stateify savable type PrependPathAtVFSRootError struct{} // Error implements error.Error. @@ -516,6 +518,8 @@ func (PrependPathAtVFSRootError) Error() string { // PrependPathAtNonMountRootError is returned by implementations of // FilesystemImpl.PrependPath() when they encounter an independent ancestor // Dentry that is not the Mount root. +// +// +stateify savable type PrependPathAtNonMountRootError struct{} // Error implements error.Error. @@ -526,6 +530,8 @@ func (PrependPathAtNonMountRootError) Error() string { // PrependPathSyntheticError is returned by implementations of // FilesystemImpl.PrependPath() for which prepended names do not represent real // paths. +// +// +stateify savable type PrependPathSyntheticError struct{} // Error implements error.Error. diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go index f2298f7f6..bc19db1d5 100644 --- a/pkg/sentry/vfs/filesystem_type.go +++ b/pkg/sentry/vfs/filesystem_type.go @@ -55,10 +55,13 @@ type registeredFilesystemType struct { // RegisterFilesystemTypeOptions contains options to // VirtualFilesystem.RegisterFilesystem(). +// +// +stateify savable type RegisterFilesystemTypeOptions struct { - // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt() - // for which MountOptions.InternalMount == false to use this filesystem - // type. + // AllowUserMount determines whether users are allowed to mount a file system + // of this type, i.e. through mount(2). If AllowUserMount is true, allow calls + // to VirtualFilesystem.MountAt() for which MountOptions.InternalMount == false + // to use this filesystem type. AllowUserMount bool // If AllowUserList is true, make this filesystem type visible in diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 8882fa84a..2d27d9d35 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -27,6 +27,8 @@ import ( ) // Dentry is a required type parameter that is a struct with the given fields. +// +// +stateify savable type Dentry struct { // vfsd is the embedded vfs.Dentry corresponding to this vfs.DentryImpl. vfsd vfs.Dentry diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index aff220a61..3f0b8f45b 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -37,6 +37,8 @@ const inotifyEventBaseSize = 16 // // The way events are labelled appears somewhat arbitrary, but they must match // Linux so that IN_EXCL_UNLINK behaves as it does in Linux. +// +// +stateify savable type EventType uint8 // PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 42666eebf..55783d4eb 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -33,6 +33,8 @@ import ( // Note that in Linux these two types of locks are _not_ cooperative, because // race and deadlock conditions make merging them prohibitive. We do the same // and keep them oblivious to each other. +// +// +stateify savable type FileLocks struct { // bsd is a set of BSD-style advisory file wide locks, see flock(2). bsd fslock.Locks diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go index cc1e7d764..638b5d830 100644 --- a/pkg/sentry/vfs/memxattr/xattr.go +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -33,8 +33,8 @@ type SimpleExtendedAttributes struct { xattrs map[string]string } -// Getxattr returns the value at 'name'. -func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) { +// GetXattr returns the value at 'name'. +func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) { x.mu.RLock() value, ok := x.xattrs[opts.Name] x.mu.RUnlock() @@ -49,8 +49,8 @@ func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, return value, nil } -// Setxattr sets 'value' at 'name'. -func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { +// SetXattr sets 'value' at 'name'. +func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { x.mu.Lock() defer x.mu.Unlock() if x.xattrs == nil { @@ -72,8 +72,8 @@ func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { return nil } -// Listxattr returns all names in xattrs. -func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { +// ListXattr returns all names in xattrs. +func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { // Keep track of the size of the buffer needed in listxattr(2) for the list. listSize := 0 x.mu.RLock() @@ -90,8 +90,8 @@ func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { return names, nil } -// Removexattr removes the xattr at 'name'. -func (x *SimpleExtendedAttributes) Removexattr(name string) error { +// RemoveXattr removes the xattr at 'name'. +func (x *SimpleExtendedAttributes) RemoveXattr(name string) error { x.mu.Lock() defer x.mu.Unlock() if _, ok := x.xattrs[name]; !ok { diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index db5fb3bb1..dfc3ae6c0 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -18,14 +18,12 @@ import ( "bytes" "fmt" "math" - "path" "sort" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) @@ -67,7 +65,7 @@ type Mount struct { // // Invariant: key.parent != nil iff key.point != nil. key.point belongs to // key.parent.fs. - key mountKey + key mountKey `state:".(VirtualDentry)"` // ns is the namespace in which this Mount was mounted. ns is protected by // VirtualFilesystem.mountMu. @@ -154,13 +152,13 @@ type MountNamespace struct { // NewMountNamespace returns a new mount namespace with a root filesystem // configured by the given arguments. A reference is taken on the returned // MountNamespace. -func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { +func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *MountOptions) (*MountNamespace, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { ctx.Warningf("Unknown filesystem type: %s", fsTypeName) return nil, syserror.ENODEV } - fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts) + fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return nil, err } @@ -169,7 +167,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth mountpoints: make(map[*Dentry]uint32), } mntns.EnableLeakCheck() - mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{}) + mntns.root = newMount(vfs, fs, root, mntns, opts) return mntns, nil } @@ -347,6 +345,7 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti return nil } +// +stateify savable type umountRecursiveOptions struct { // If eager is true, ensure that future calls to Mount.tryIncMountedRef() // on umounted mounts fail. @@ -416,7 +415,7 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns } } mnt.IncRef() // dropped by callers of umountRecursiveLocked - mnt.storeKey(vd) + mnt.setKey(vd) if vd.mount.children == nil { vd.mount.children = make(map[*Mount]struct{}) } @@ -441,13 +440,13 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns // * vfs.mounts.seq must be in a writer critical section. // * mnt.parent() != nil. func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { - vd := mnt.loadKey() + vd := mnt.getKey() if checkInvariants { if vd.mount != nil { panic("VFS.disconnectLocked called on disconnected mount") } } - mnt.storeKey(VirtualDentry{}) + mnt.loadKey(VirtualDentry{}) delete(vd.mount.children, mnt) atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1 mnt.ns.mountpoints[vd.dentry]-- @@ -740,11 +739,23 @@ func (mntns *MountNamespace) Root() VirtualDentry { // // Preconditions: taskRootDir.Ok(). func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) { - vfs.mountMu.Lock() - defer vfs.mountMu.Unlock() rootMnt := taskRootDir.mount + + vfs.mountMu.Lock() mounts := rootMnt.submountsLocked() + // Take a reference on mounts since we need to drop vfs.mountMu before + // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()). + for _, mnt := range mounts { + mnt.IncRef() + } + vfs.mountMu.Unlock() + defer func() { + for _, mnt := range mounts { + mnt.DecRef(ctx) + } + }() sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID }) + for _, mnt := range mounts { // Get the path to this mount relative to task root. mntRootVD := VirtualDentry{ @@ -755,7 +766,7 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if err != nil { // For some reason we didn't get a path. Log a warning // and run with empty path. - ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err) + ctx.Warningf("VFS.GenerateProcMounts: error getting pathname for mount root %+v: %v", mnt.root, err) path = "" } if path == "" { @@ -789,11 +800,25 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi // // Preconditions: taskRootDir.Ok(). func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) { - vfs.mountMu.Lock() - defer vfs.mountMu.Unlock() rootMnt := taskRootDir.mount + + vfs.mountMu.Lock() mounts := rootMnt.submountsLocked() + // Take a reference on mounts since we need to drop vfs.mountMu before + // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()) or + // vfs.StatAt() (=> FilesystemImpl.StatAt()). + for _, mnt := range mounts { + mnt.IncRef() + } + vfs.mountMu.Unlock() + defer func() { + for _, mnt := range mounts { + mnt.DecRef(ctx) + } + }() sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID }) + + creds := auth.CredentialsFromContext(ctx) for _, mnt := range mounts { // Get the path to this mount relative to task root. mntRootVD := VirtualDentry{ @@ -804,7 +829,7 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo if err != nil { // For some reason we didn't get a path. Log a warning // and run with empty path. - ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err) + ctx.Warningf("VFS.GenerateProcMountInfo: error getting pathname for mount root %+v: %v", mnt.root, err) path = "" } if path == "" { @@ -817,9 +842,10 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo Root: mntRootVD, Start: mntRootVD, } - statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{}) + statx, err := vfs.StatAt(ctx, creds, pop, &StatOptions{}) if err != nil { // Well that's not good. Ignore this mount. + ctx.Warningf("VFS.GenerateProcMountInfo: failed to stat mount root %+v: %v", mnt.root, err) break } @@ -831,6 +857,9 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo fmt.Fprintf(buf, "%d ", mnt.ID) // (2) Parent ID (or this ID if there is no parent). + // Note that even if the call to mnt.parent() races with Mount + // destruction (which is possible since we're not holding vfs.mountMu), + // its Mount.ID will still be valid. pID := mnt.ID if p := mnt.parent(); p != nil { pID = p.ID @@ -879,30 +908,6 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo } } -// MakeSyntheticMountpoint creates parent directories of target if they do not -// exist and attempts to create a directory for the mountpoint. If a -// non-directory file already exists there then we allow it. -func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, target string, root VirtualDentry, creds *auth.Credentials) error { - mkdirOpts := &MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true} - - // Make sure the parent directory of target exists. - if err := vfs.MkdirAllAt(ctx, path.Dir(target), root, creds, mkdirOpts); err != nil { - return fmt.Errorf("failed to create parent directory of mountpoint %q: %w", target, err) - } - - // Attempt to mkdir the final component. If a file (of any type) exists - // then we let allow mounting on top of that because we do not require the - // target to be an existing directory, unlike Linux mount(2). - if err := vfs.MkdirAt(ctx, creds, &PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(target), - }, mkdirOpts); err != nil && err != syserror.EEXIST { - return fmt.Errorf("failed to create mountpoint %q: %w", target, err) - } - return nil -} - // manglePath replaces ' ', '\t', '\n', and '\\' with their octal equivalents. // See Linux fs/seq_file.c:mangle_path. func manglePath(p string) string { diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index 3335e4057..cb8c56bd3 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -38,7 +38,7 @@ func TestMountTableInsertLookup(t *testing.T) { mt.Init() mount := &Mount{} - mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) + mount.setKey(VirtualDentry{&Mount{}, &Dentry{}}) mt.Insert(mount) if m := mt.Lookup(mount.parent(), mount.point()); m != mount { @@ -79,7 +79,7 @@ const enableComparativeBenchmarks = false func newBenchMount() *Mount { mount := &Mount{} - mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}}) + mount.loadKey(VirtualDentry{&Mount{}, &Dentry{}}) return mount } @@ -94,7 +94,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { for i := 0; i < numMounts; i++ { mount := newBenchMount() mt.Insert(mount) - keys = append(keys, mount.loadKey()) + keys = append(keys, mount.saveKey()) } var ready sync.WaitGroup @@ -146,7 +146,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := mount.loadKey() + key := mount.saveKey() ms[key] = mount keys = append(keys, key) } @@ -201,7 +201,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { keys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { mount := newBenchMount() - key := mount.loadKey() + key := mount.getKey() ms.Store(key, mount) keys = append(keys, key) } @@ -283,7 +283,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms[mount.loadKey()] = mount + ms[mount.getKey()] = mount } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -318,7 +318,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { var ms sync.Map for i := 0; i < numMounts; i++ { mount := newBenchMount() - ms.Store(mount.loadKey(), mount) + ms.Store(mount.saveKey(), mount) } negkeys := make([]VirtualDentry, 0, numMounts) for i := 0; i < numMounts; i++ { @@ -372,7 +372,7 @@ func BenchmarkMountMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms[mount.loadKey()] = mount + ms[mount.saveKey()] = mount } } @@ -392,7 +392,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) { b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Store(mount.loadKey(), mount) + ms.Store(mount.saveKey(), mount) } } @@ -425,13 +425,13 @@ func BenchmarkMountMapRemove(b *testing.B) { ms := make(map[VirtualDentry]*Mount) for i := range mounts { mount := mounts[i] - ms[mount.loadKey()] = mount + ms[mount.saveKey()] = mount } b.ResetTimer() for i := range mounts { mount := mounts[i] - delete(ms, mount.loadKey()) + delete(ms, mount.saveKey()) } } @@ -447,12 +447,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) { var ms sync.Map for i := range mounts { mount := mounts[i] - ms.Store(mount.loadKey(), mount) + ms.Store(mount.saveKey(), mount) } b.ResetTimer() for i := range mounts { mount := mounts[i] - ms.Delete(mount.loadKey()) + ms.Delete(mount.saveKey()) } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index da2a2e9c4..b7d122d22 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -34,6 +34,8 @@ import ( // structurally identical to VirtualDentry, but stores its fields as // unsafe.Pointer since mutators synchronize with VFS path traversal using // seqcounts. +// +// This is explicitly not savable. type mountKey struct { parent unsafe.Pointer // *Mount point unsafe.Pointer // *Dentry @@ -47,19 +49,23 @@ func (mnt *Mount) point() *Dentry { return (*Dentry)(atomic.LoadPointer(&mnt.key.point)) } -func (mnt *Mount) loadKey() VirtualDentry { +func (mnt *Mount) getKey() VirtualDentry { return VirtualDentry{ mount: mnt.parent(), dentry: mnt.point(), } } +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // Invariant: mnt.key.parent == nil. vd.Ok(). -func (mnt *Mount) storeKey(vd VirtualDentry) { +func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). @@ -92,6 +98,7 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 + // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index dfc8573fd..bc79e5ecc 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -21,6 +21,8 @@ import ( // GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and // FilesystemImpl.GetDentryAt(). +// +// +stateify savable type GetDentryOptions struct { // If CheckSearchable is true, FilesystemImpl.GetDentryAt() must check that // the returned Dentry is a directory for which creds has search @@ -30,6 +32,8 @@ type GetDentryOptions struct { // MkdirOptions contains options to VirtualFilesystem.MkdirAt() and // FilesystemImpl.MkdirAt(). +// +// +stateify savable type MkdirOptions struct { // Mode is the file mode bits for the created directory. Mode linux.FileMode @@ -56,6 +60,8 @@ type MkdirOptions struct { // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). +// +// +stateify savable type MknodOptions struct { // Mode is the file type and mode bits for the created file. Mode linux.FileMode @@ -72,6 +78,8 @@ type MknodOptions struct { // MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC. // MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers. +// +// +stateify savable type MountFlags struct { // NoExec is equivalent to MS_NOEXEC. NoExec bool @@ -93,6 +101,8 @@ type MountFlags struct { } // MountOptions contains options to VirtualFilesystem.MountAt(). +// +// +stateify savable type MountOptions struct { // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC. Flags MountFlags @@ -103,13 +113,17 @@ type MountOptions struct { // GetFilesystemOptions contains options to FilesystemType.GetFilesystem(). GetFilesystemOptions GetFilesystemOptions - // If InternalMount is true, allow the use of filesystem types for which - // RegisterFilesystemTypeOptions.AllowUserMount == false. + // InternalMount indicates whether the mount operation is coming from the + // application, i.e. through mount(2). If InternalMount is true, allow the use + // of filesystem types for which RegisterFilesystemTypeOptions.AllowUserMount + // == false. InternalMount bool } // OpenOptions contains options to VirtualFilesystem.OpenAt() and // FilesystemImpl.OpenAt(). +// +// +stateify savable type OpenOptions struct { // Flags contains access mode and flags as specified for open(2). // @@ -135,6 +149,8 @@ type OpenOptions struct { // ReadOptions contains options to FileDescription.PRead(), // FileDescriptionImpl.PRead(), FileDescription.Read(), and // FileDescriptionImpl.Read(). +// +// +stateify savable type ReadOptions struct { // Flags contains flags as specified for preadv2(2). Flags uint32 @@ -142,6 +158,8 @@ type ReadOptions struct { // RenameOptions contains options to VirtualFilesystem.RenameAt() and // FilesystemImpl.RenameAt(). +// +// +stateify savable type RenameOptions struct { // Flags contains flags as specified for renameat2(2). Flags uint32 @@ -153,6 +171,8 @@ type RenameOptions struct { // SetStatOptions contains options to VirtualFilesystem.SetStatAt(), // FilesystemImpl.SetStatAt(), FileDescription.SetStat(), and // FileDescriptionImpl.SetStat(). +// +// +stateify savable type SetStatOptions struct { // Stat is the metadata that should be set. Only fields indicated by // Stat.Mask should be set. @@ -174,6 +194,8 @@ type SetStatOptions struct { // BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt() // and FilesystemImpl.BoundEndpointAt(). +// +// +stateify savable type BoundEndpointOptions struct { // Addr is the path of the file whose socket endpoint is being retrieved. // It is generally irrelevant: most endpoints are stored at a dentry that @@ -190,10 +212,12 @@ type BoundEndpointOptions struct { Addr string } -// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(), -// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and -// FileDescriptionImpl.Getxattr(). -type GetxattrOptions struct { +// GetXattrOptions contains options to VirtualFilesystem.GetXattrAt(), +// FilesystemImpl.GetXattrAt(), FileDescription.GetXattr(), and +// FileDescriptionImpl.GetXattr(). +// +// +stateify savable +type GetXattrOptions struct { // Name is the name of the extended attribute to retrieve. Name string @@ -204,10 +228,12 @@ type GetxattrOptions struct { Size uint64 } -// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), -// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and -// FileDescriptionImpl.Setxattr(). -type SetxattrOptions struct { +// SetXattrOptions contains options to VirtualFilesystem.SetXattrAt(), +// FilesystemImpl.SetXattrAt(), FileDescription.SetXattr(), and +// FileDescriptionImpl.SetXattr(). +// +// +stateify savable +type SetXattrOptions struct { // Name is the name of the extended attribute being mutated. Name string @@ -221,6 +247,8 @@ type SetxattrOptions struct { // StatOptions contains options to VirtualFilesystem.StatAt(), // FilesystemImpl.StatAt(), FileDescription.Stat(), and // FileDescriptionImpl.Stat(). +// +// +stateify savable type StatOptions struct { // Mask is the set of fields in the returned Statx that the FilesystemImpl // or FileDescriptionImpl should provide. Bits are as in linux.Statx.Mask. @@ -238,6 +266,8 @@ type StatOptions struct { } // UmountOptions contains options to VirtualFilesystem.UmountAt(). +// +// +stateify savable type UmountOptions struct { // Flags contains flags as specified for umount2(2). Flags uint32 @@ -246,6 +276,8 @@ type UmountOptions struct { // WriteOptions contains options to FileDescription.PWrite(), // FileDescriptionImpl.PWrite(), FileDescription.Write(), and // FileDescriptionImpl.Write(). +// +// +stateify savable type WriteOptions struct { // Flags contains flags as specified for pwritev2(2). Flags uint32 diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 014b928ed..d48520d58 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -16,6 +16,7 @@ package vfs import ( "math" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -25,6 +26,8 @@ import ( ) // AccessTypes is a bitmask of Unix file permissions. +// +// +stateify savable type AccessTypes uint16 // Bits in AccessTypes. @@ -284,3 +287,40 @@ func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { } return size, nil } + +// CheckXattrPermissions checks permissions for extended attribute access. +// This is analogous to fs/xattr.c:xattr_permission(). Some key differences: +// * Does not check for read-only filesystem property. +// * Does not check inode immutability or append only mode. In both cases EPERM +// must be returned by filesystem implementations. +// * Does not do inode permission checks. Filesystem implementations should +// handle inode permission checks as they may differ across implementations. +func CheckXattrPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, name string) error { + switch { + case strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX): + // The trusted.* namespace can only be accessed by privileged + // users. + if creds.HasCapability(linux.CAP_SYS_ADMIN) { + return nil + } + if ats.MayWrite() { + return syserror.EPERM + } + return syserror.ENODATA + case strings.HasPrefix(name, linux.XATTR_USER_PREFIX): + // In the user.* namespace, only regular files and directories can have + // extended attributes. For sticky directories, only the owner and + // privileged users can write attributes. + filetype := mode.FileType() + if filetype != linux.ModeRegular && filetype != linux.ModeDirectory { + if ats.MayWrite() { + return syserror.EPERM + } + return syserror.ENODATA + } + if filetype == linux.ModeDirectory && mode&linux.ModeSticky != 0 && ats.MayWrite() && !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + } + return nil +} diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index 3304372d9..e4fd55012 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -35,6 +35,8 @@ import ( // FilesystemImpl methods. // // ResolvingPath is loosely analogous to Linux's struct nameidata. +// +// +stateify savable type ResolvingPath struct { vfs *VirtualFilesystem root VirtualDentry // refs borrowed from PathOperation @@ -88,6 +90,7 @@ func init() { // so error "constants" are really mutable vars, necessitating somewhat // expensive interface object comparisons. +// +stateify savable type resolveMountRootOrJumpError struct{} // Error implements error.Error. @@ -95,6 +98,7 @@ func (resolveMountRootOrJumpError) Error() string { return "resolving mount root or jump" } +// +stateify savable type resolveMountPointError struct{} // Error implements error.Error. @@ -102,6 +106,7 @@ func (resolveMountPointError) Error() string { return "resolving mount point" } +// +stateify savable type resolveAbsSymlinkError struct{} // Error implements error.Error. diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index ec27562d6..5bd756ea5 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -163,6 +163,8 @@ func (vfs *VirtualFilesystem) Init(ctx context.Context) error { // PathOperation is passed to VFS methods by pointer to reduce memory copying: // it's somewhat large and should never escape. (Options structs are passed by // pointer to VFS and FileDescription methods for the same reason.) +// +// +stateify savable type PathOperation struct { // Root is the VFS root. References on Root are borrowed from the provider // of the PathOperation. @@ -297,6 +299,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential // MkdirAt creates a directory at the given path. func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with mkdirat(dirfd, "", mode). if pop.Path.Absolute { return syserror.EEXIST } @@ -333,6 +337,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia // error from the syserror package. func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with mknodat(dirfd, "", mode, dev). if pop.Path.Absolute { return syserror.EEXIST } @@ -518,6 +524,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti // RmdirAt removes the directory at the given path. func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with unlinkat(dirfd, "", AT_REMOVEDIR). if pop.Path.Absolute { return syserror.EBUSY } @@ -599,6 +607,8 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti // SymlinkAt creates a symbolic link at the given path with the given target. func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with symlinkat(oldpath, newdirfd, ""). if pop.Path.Absolute { return syserror.EEXIST } @@ -631,6 +641,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent // UnlinkAt deletes the non-directory file at the given path. func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with unlinkat(dirfd, "", 0). if pop.Path.Absolute { return syserror.EBUSY } @@ -662,12 +674,6 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti // BoundEndpointAt gets the bound endpoint at the given path, if one exists. func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) { - if !pop.Path.Begin.Ok() { - if pop.Path.Absolute { - return nil, syserror.ECONNREFUSED - } - return nil, syserror.ENOENT - } rp := vfs.getResolvingPath(creds, pop) for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) @@ -687,12 +693,12 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C } } -// ListxattrAt returns all extended attribute names for the file at the given +// ListXattrAt returns all extended attribute names for the file at the given // path. -func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { +func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { rp := vfs.getResolvingPath(creds, pop) for { - names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size) + names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size) if err == nil { vfs.putResolvingPath(ctx, rp) return names, nil @@ -712,12 +718,12 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede } } -// GetxattrAt returns the value associated with the given extended attribute +// GetXattrAt returns the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) { +func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetXattrOptions) (string, error) { rp := vfs.getResolvingPath(creds, pop) for { - val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts) + val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(ctx, rp) return val, nil @@ -729,12 +735,12 @@ func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Creden } } -// SetxattrAt changes the value associated with the given extended attribute +// SetXattrAt changes the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { +func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetXattrOptions) error { rp := vfs.getResolvingPath(creds, pop) for { - err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(ctx, rp) return nil @@ -746,11 +752,11 @@ func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Creden } } -// RemovexattrAt removes the given extended attribute from the file at rp. -func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { +// RemoveXattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { rp := vfs.getResolvingPath(creds, pop) for { - err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name) if err == nil { vfs.putResolvingPath(ctx, rp) return nil @@ -815,6 +821,30 @@ func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string return nil } +// MakeSyntheticMountpoint creates parent directories of target if they do not +// exist and attempts to create a directory for the mountpoint. If a +// non-directory file already exists there then we allow it. +func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, target string, root VirtualDentry, creds *auth.Credentials) error { + mkdirOpts := &MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true} + + // Make sure the parent directory of target exists. + if err := vfs.MkdirAllAt(ctx, path.Dir(target), root, creds, mkdirOpts); err != nil { + return fmt.Errorf("failed to create parent directory of mountpoint %q: %w", target, err) + } + + // Attempt to mkdir the final component. If a file (of any type) exists + // then we let allow mounting on top of that because we do not require the + // target to be an existing directory, unlike Linux mount(2). + if err := vfs.MkdirAt(ctx, creds, &PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(target), + }, mkdirOpts); err != nil && err != syserror.EEXIST { + return fmt.Errorf("failed to create mountpoint %q: %w", target, err) + } + return nil +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). diff --git a/pkg/state/types.go b/pkg/state/types.go index 215ef80f8..84aed8732 100644 --- a/pkg/state/types.go +++ b/pkg/state/types.go @@ -107,6 +107,14 @@ func lookupNameFields(typ reflect.Type) (string, []string, bool) { } return name, nil, true } + // Sanity check the type. + if raceEnabled { + if _, ok := reverseTypeDatabase[typ]; !ok { + // The type was not registered? Must be an embedded + // structure or something else. + return "", nil, false + } + } // Extract the name from the object. name := t.StateTypeName() fields := t.StateFields() @@ -313,6 +321,9 @@ var primitiveTypeDatabase = func() map[string]reflect.Type { // globalTypeDatabase is used for dispatching interfaces on decode. var globalTypeDatabase = map[string]reflect.Type{} +// reverseTypeDatabase is a reverse mapping. +var reverseTypeDatabase = map[reflect.Type]string{} + // Register registers a type. // // This must be called on init and only done once. @@ -358,4 +369,7 @@ func Register(t Type) { Failf("conflicting name for %T: matches interfaceType", t) } globalTypeDatabase[name] = typ + if raceEnabled { + reverseTypeDatabase[typ] = name + } } diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 4d47207f7..68535c3b1 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -38,6 +38,7 @@ go_library( "race_unsafe.go", "rwmutex_unsafe.go", "seqcount.go", + "spin_unsafe.go", "sync.go", ], marshal = False, diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go index eda6fb131..2184cb5ab 100644 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sync/seqatomic_unsafe.go @@ -25,41 +25,35 @@ import ( type Value struct{} // SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value +// with any writer critical sections in seq. +// +//go:nosplit +func SeqAtomicLoad(seq *sync.SeqCount, ptr *Value) Value { for { - epoch := sc.BeginRead() - if sync.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break + if val, ok := SeqAtomicTryLoad(seq, seq.BeginRead(), ptr); ok { + return val } } - return val } // SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns +// in seq initiated by a call to seq.BeginRead() that returned epoch. If the +// read would race with a writer critical section, SeqAtomicTryLoad returns // (unspecified, false). -func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value +// +//go:nosplit +func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (val Value, ok bool) { if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, so it + // can't help us here. Instead, call runtime.memmove directly, which is + // not instrumented by the race detector. sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { + // This is ~40% faster for short reads than going through memmove. val = *ptr } - return val, sc.ReadOk(epoch) + ok = seq.ReadOk(epoch) + return } func init() { diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go index a1e895352..2c5d3df99 100644 --- a/pkg/sync/seqcount.go +++ b/pkg/sync/seqcount.go @@ -8,7 +8,6 @@ package sync import ( "fmt" "reflect" - "runtime" "sync/atomic" ) @@ -43,9 +42,7 @@ type SeqCount struct { } // SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} +type SeqCountEpoch uint32 // We assume that: // @@ -83,12 +80,25 @@ type SeqCountEpoch struct { // using this pattern. Most users of SeqCount will need to use the // SeqAtomicLoad function template in seqatomic.go. func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) + if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 { + return SeqCountEpoch(epoch) + } + return s.beginReadSlow() +} + +func (s *SeqCount) beginReadSlow() SeqCountEpoch { + i := 0 + for { + if canSpin(i) { + i++ + doSpin() + } else { + goyield() + } + if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 { + return SeqCountEpoch(epoch) + } } - return SeqCountEpoch{epoch} } // ReadOk returns true if the reader critical section initiated by a previous @@ -99,7 +109,7 @@ func (s *SeqCount) BeginRead() SeqCountEpoch { // Reader critical sections do not need to be explicitly terminated; the last // call to ReadOk is implicitly the end of the reader critical section. func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val + return atomic.LoadUint32(&s.epoch) == uint32(epoch) } // BeginWrite indicates the beginning of a writer critical section. diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/spin_unsafe.go new file mode 100644 index 000000000..cafb2d065 --- /dev/null +++ b/pkg/sync/spin_unsafe.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.17 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + _ "unsafe" // for go:linkname +) + +//go:linkname canSpin sync.runtime_canSpin +func canSpin(i int) bool + +//go:linkname doSpin sync.runtime_doSpin +func doSpin() + +//go:linkname goyield runtime.goyield +func goyield() diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index fe9f50169..f516c8e46 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -33,6 +33,7 @@ var ( EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) + ECONNABORTED = error(syscall.ECONNABORTED) ECONNREFUSED = error(syscall.ECONNREFUSED) ECONNRESET = error(syscall.ECONNRESET) EDEADLK = error(syscall.EDEADLK) diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go index 29719752e..7036467c4 100644 --- a/pkg/syserror/syserror_test.go +++ b/pkg/syserror/syserror_test.go @@ -24,27 +24,20 @@ import ( var globalError error -func returnErrnoAsError() error { - return syscall.EINVAL -} - -func returnError() error { - return syserror.EINVAL -} - -func BenchmarkReturnErrnoAsError(b *testing.B) { +func BenchmarkAssignErrno(b *testing.B) { for i := b.N; i > 0; i-- { - returnErrnoAsError() + globalError = syscall.EINVAL } } -func BenchmarkReturnError(b *testing.B) { +func BenchmarkAssignError(b *testing.B) { for i := b.N; i > 0; i-- { - returnError() + globalError = syserror.EINVAL } } func BenchmarkCompareErrno(b *testing.B) { + globalError = syscall.EAGAIN j := 0 for i := b.N; i > 0; i-- { if globalError == syscall.EINVAL { @@ -54,6 +47,7 @@ func BenchmarkCompareErrno(b *testing.B) { } func BenchmarkCompareError(b *testing.B) { + globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { if globalError == syserror.EINVAL { @@ -63,6 +57,7 @@ func BenchmarkCompareError(b *testing.B) { } func BenchmarkSwitchErrno(b *testing.B) { + globalError = syscall.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { @@ -77,6 +72,7 @@ func BenchmarkSwitchErrno(b *testing.B) { } func BenchmarkSwitchError(b *testing.B) { + globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 68a954a10..4f551cd92 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { // Accept implements net.Conn.Accept. func (l *TCPListener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() + n, wq, err := l.ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { defer l.wq.EventUnregister(&waitEntry) for { - n, wq, err = l.ep.Accept() + n, wq, err = l.ep.Accept(nil) if err != tcpip.ErrWouldBlock { break diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index c975ad9cf..12b061def 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -61,8 +61,8 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 563bc78ea..c326fab54 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -14,6 +14,8 @@ go_library( go_test( name = "buffer_test", size = "small", - srcs = ["view_test.go"], + srcs = [ + "view_test.go", + ], library = ":buffer", ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index b769094dc..71b2f1bda 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -118,18 +118,62 @@ func TTL(ttl uint8) NetworkChecker { v = ip.HopLimit() } if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) + } + } +} + +// IPFullLength creates a checker for the full IP packet length. The +// expected size is checked against both the Total Length in the +// header and the number of bytes received. +func IPFullLength(packetLength uint16) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + var v uint16 + var l uint16 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TotalLength() + l = uint16(len(ip)) + case header.IPv6: + v = ip.PayloadLength() + header.IPv6FixedHeaderSize + l = uint16(len(ip)) + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) + } + if l != packetLength { + t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) + } + if v != packetLength { + t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) + } + } +} + +// IPv4HeaderLength creates a checker that checks the IPv4 Header length. +func IPv4HeaderLength(headerLength int) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + switch ip := h[0].(type) { + case header.IPv4: + if hl := ip.HeaderLength(); hl != uint8(headerLength) { + t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) + } + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) } } } // PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { +func PayloadLen(payloadLength int) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) + if l := len(h[0].Payload()); l != payloadLength { + t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) } } } @@ -143,7 +187,7 @@ func FragmentOffset(offset uint16) NetworkChecker { switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) } } } @@ -158,7 +202,7 @@ func FragmentFlags(flags uint8) NetworkChecker { switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) } } } @@ -208,7 +252,7 @@ func TOS(tos uint8, label uint32) NetworkChecker { t.Helper() if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) } } } @@ -234,7 +278,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { t.Helper() if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) @@ -261,7 +305,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -297,7 +341,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) @@ -316,7 +360,7 @@ func SrcPort(port uint16) TransportChecker { t.Helper() if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got = %d, want = %d", p, port) } } } @@ -327,7 +371,7 @@ func DstPort(port uint16) TransportChecker { t.Helper() if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got = %d, want = %d", p, port) } } } @@ -339,7 +383,7 @@ func NoChecksum(noChecksum bool) TransportChecker { udp, ok := h.(header.UDP) if !ok { - return + t.Fatalf("UDP header not found in h: %T", h) } if b := udp.Checksum() == 0; b != noChecksum { @@ -348,50 +392,84 @@ func NoChecksum(noChecksum bool) TransportChecker { } } -// SeqNum creates a checker that checks the sequence number. -func SeqNum(seq uint32) TransportChecker { +// TCPSeqNum creates a checker that checks the sequence number. +func TCPSeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) } } } -// AckNum creates a checker that checks the ack number. -func AckNum(seq uint32) TransportChecker { +// TCPAckNum creates a checker that checks the ack number. +func TCPAckNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got = %d, want = %d", s, seq) } } } -// Window creates a checker that checks the tcp window. -func Window(window uint16) TransportChecker { +// TCPWindow creates a checker that checks the tcp window. +func TCPWindow(window uint16) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in hdr : %T", h) } if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) + t.Errorf("Bad window, got %d, want %d", w, window) + } + } +} + +// TCPWindowGreaterThanEq creates a checker that checks that the TCP window +// is greater than or equal to the provided value. +func TCPWindowGreaterThanEq(window uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + tcp, ok := h.(header.TCP) + if !ok { + t.Fatalf("TCP header not found in h: %T", h) + } + + if w := tcp.WindowSize(); w < window { + t.Errorf("Bad window, got %d, want > %d", w, window) + } + } +} + +// TCPWindowLessThanEq creates a checker that checks that the tcp window +// is less than or equal to the provided value. +func TCPWindowLessThanEq(window uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + tcp, ok := h.(header.TCP) + if !ok { + t.Fatalf("TCP header not found in h: %T", h) + } + + if w := tcp.WindowSize(); w > window { + t.Errorf("Bad window, got %d, want < %d", w, window) } } } @@ -403,7 +481,7 @@ func TCPFlags(flags uint8) TransportChecker { tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if f := tcp.Flags(); f != flags { @@ -420,7 +498,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if f := tcp.Flags(); (f & mask) != (flags & mask) { @@ -458,7 +536,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) } foundMSS = true i += 4 @@ -468,7 +546,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } v := int(opts[i+2]) if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) } foundWS = true i += 3 @@ -517,7 +595,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { t.Errorf("SACKPermitted option not found. Options: %x", opts) @@ -555,7 +633,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -575,13 +653,13 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) } } } @@ -645,7 +723,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) } } } @@ -690,10 +768,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -705,10 +783,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. +func ICMPv4Ident(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Ident(); got != want { + t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. +func ICMPv4Seq(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Sequence(); got != want { + t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. +// This assumes that the payload exactly makes up the rest of the slice. +func ICMPv4Checksum() TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + heldChecksum := icmpv4.Checksum() + icmpv4.SetChecksum(0) + newChecksum := ^header.Checksum(icmpv4, 0) + icmpv4.SetChecksum(heldChecksum) + if heldChecksum != newChecksum { + t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) + } + } +} + +// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. +func ICMPv4Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + payload := icmpv4.Payload() + if diff := cmp.Diff(payload, want); diff != "" { + t.Errorf("got ICMP payload mismatch (-want +got):\n%s", diff) } } } @@ -748,10 +892,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -763,10 +907,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) } } } diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD new file mode 100644 index 000000000..114d43df3 --- /dev/null +++ b/pkg/tcpip/faketime/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "faketime", + srcs = ["faketime.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@com_github_dpjacques_clockwork//:go_default_library", + ], +) + +go_test( + name = "faketime_test", + size = "small", + srcs = [ + "faketime_test.go", + ], + deps = [ + "//pkg/tcpip/faketime", + ], +) diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/faketime/faketime.go index 92c8cb534..f7a4fbde1 100644 --- a/pkg/tcpip/stack/fake_time_test.go +++ b/pkg/tcpip/faketime/faketime.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +// Package faketime provides a fake clock that implements tcpip.Clock interface. +package faketime import ( "container/heap" @@ -23,7 +24,29 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -type fakeClock struct { +// NullClock implements a clock that never advances. +type NullClock struct{} + +var _ tcpip.Clock = (*NullClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (*NullClock) NowNanoseconds() int64 { + return 0 +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (*NullClock) NowMonotonic() int64 { + return 0 +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { + return nil +} + +// ManualClock implements tcpip.Clock and only advances manually with Advance +// method. +type ManualClock struct { clock clockwork.FakeClock // mu protects the fields below. @@ -39,34 +62,35 @@ type fakeClock struct { waitGroups map[time.Time]*sync.WaitGroup } -func newFakeClock() *fakeClock { - return &fakeClock{ +// NewManualClock creates a new ManualClock instance. +func NewManualClock() *ManualClock { + return &ManualClock{ clock: clockwork.NewFakeClock(), times: &timeHeap{}, waitGroups: make(map[time.Time]*sync.WaitGroup), } } -var _ tcpip.Clock = (*fakeClock)(nil) +var _ tcpip.Clock = (*ManualClock)(nil) // NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (fc *fakeClock) NowNanoseconds() int64 { - return fc.clock.Now().UnixNano() +func (mc *ManualClock) NowNanoseconds() int64 { + return mc.clock.Now().UnixNano() } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (fc *fakeClock) NowMonotonic() int64 { - return fc.NowNanoseconds() +func (mc *ManualClock) NowMonotonic() int64 { + return mc.NowNanoseconds() } // AfterFunc implements tcpip.Clock.AfterFunc. -func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := fc.clock.Now().Add(d) - wg := fc.addWait(until) - return &fakeTimer{ - clock: fc, +func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := mc.clock.Now().Add(d) + wg := mc.addWait(until) + return &manualTimer{ + clock: mc, until: until, - timer: fc.clock.AfterFunc(d, func() { + timer: mc.clock.AfterFunc(d, func() { defer wg.Done() f() }), @@ -75,110 +99,113 @@ func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { // addWait adds an additional wait to the WaitGroup for parallel execution of // all work scheduled for t. Returns a reference to the WaitGroup modified. -func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { - fc.mu.RLock() - wg, ok := fc.waitGroups[t] - fc.mu.RUnlock() +func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { + mc.mu.RLock() + wg, ok := mc.waitGroups[t] + mc.mu.RUnlock() if ok { wg.Add(1) return wg } - fc.mu.Lock() - heap.Push(fc.times, t) - fc.mu.Unlock() + mc.mu.Lock() + heap.Push(mc.times, t) + mc.mu.Unlock() wg = &sync.WaitGroup{} wg.Add(1) - fc.mu.Lock() - fc.waitGroups[t] = wg - fc.mu.Unlock() + mc.mu.Lock() + mc.waitGroups[t] = wg + mc.mu.Unlock() return wg } // removeWait removes a wait from the WaitGroup for parallel execution of all // work scheduled for t. -func (fc *fakeClock) removeWait(t time.Time) { - fc.mu.RLock() - defer fc.mu.RUnlock() +func (mc *ManualClock) removeWait(t time.Time) { + mc.mu.RLock() + defer mc.mu.RUnlock() - wg := fc.waitGroups[t] + wg := mc.waitGroups[t] wg.Done() } -// advance executes all work that have been scheduled to execute within d from -// the current fake time. Blocks until all work has completed execution. -func (fc *fakeClock) advance(d time.Duration) { +// Advance executes all work that have been scheduled to execute within d from +// the current time. Blocks until all work has completed execution. +func (mc *ManualClock) Advance(d time.Duration) { // Block until all the work is done - until := fc.clock.Now().Add(d) + until := mc.clock.Now().Add(d) for { - fc.mu.Lock() - if fc.times.Len() == 0 { - fc.mu.Unlock() - return + mc.mu.Lock() + if mc.times.Len() == 0 { + mc.mu.Unlock() + break } - t := heap.Pop(fc.times).(time.Time) + t := heap.Pop(mc.times).(time.Time) if t.After(until) { // No work to do - heap.Push(fc.times, t) - fc.mu.Unlock() - return + heap.Push(mc.times, t) + mc.mu.Unlock() + break } - fc.mu.Unlock() + mc.mu.Unlock() - diff := t.Sub(fc.clock.Now()) - fc.clock.Advance(diff) + diff := t.Sub(mc.clock.Now()) + mc.clock.Advance(diff) - fc.mu.RLock() - wg := fc.waitGroups[t] - fc.mu.RUnlock() + mc.mu.RLock() + wg := mc.waitGroups[t] + mc.mu.RUnlock() wg.Wait() - fc.mu.Lock() - delete(fc.waitGroups, t) - fc.mu.Unlock() + mc.mu.Lock() + delete(mc.waitGroups, t) + mc.mu.Unlock() + } + if now := mc.clock.Now(); until.After(now) { + mc.clock.Advance(until.Sub(now)) } } -type fakeTimer struct { - clock *fakeClock +type manualTimer struct { + clock *ManualClock timer clockwork.Timer mu sync.RWMutex until time.Time } -var _ tcpip.Timer = (*fakeTimer)(nil) +var _ tcpip.Timer = (*manualTimer)(nil) // Reset implements tcpip.Timer.Reset. -func (ft *fakeTimer) Reset(d time.Duration) { - if !ft.timer.Reset(d) { +func (t *manualTimer) Reset(d time.Duration) { + if !t.timer.Reset(d) { return } - ft.mu.Lock() - defer ft.mu.Unlock() + t.mu.Lock() + defer t.mu.Unlock() - ft.clock.removeWait(ft.until) - ft.until = ft.clock.clock.Now().Add(d) - ft.clock.addWait(ft.until) + t.clock.removeWait(t.until) + t.until = t.clock.clock.Now().Add(d) + t.clock.addWait(t.until) } // Stop implements tcpip.Timer.Stop. -func (ft *fakeTimer) Stop() bool { - if !ft.timer.Stop() { +func (t *manualTimer) Stop() bool { + if !t.timer.Stop() { return false } - ft.mu.RLock() - defer ft.mu.RUnlock() + t.mu.RLock() + defer t.mu.RUnlock() - ft.clock.removeWait(ft.until) + t.clock.removeWait(t.until) return true } diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go new file mode 100644 index 000000000..c2704df2c --- /dev/null +++ b/pkg/tcpip/faketime/faketime_test.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package faketime_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" +) + +func TestManualClockAdvance(t *testing.T) { + const timeout = time.Millisecond + clock := faketime.NewManualClock() + start := clock.NowMonotonic() + clock.Advance(timeout) + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + t.Errorf("got = %d, want = %d", got, want) + } +} + +func TestManualClockAfterFunc(t *testing.T) { + const ( + timeout1 = time.Millisecond // timeout for counter1 + timeout2 = 2 * time.Millisecond // timeout for counter2 + ) + tests := []struct { + name string + advance time.Duration + wantCounter1 int + wantCounter2 int + }{ + { + name: "before timeout1", + advance: timeout1 - 1, + wantCounter1: 0, + wantCounter2: 0, + }, + { + name: "timeout1", + advance: timeout1, + wantCounter1: 1, + wantCounter2: 0, + }, + { + name: "timeout2", + advance: timeout2, + wantCounter1: 1, + wantCounter2: 1, + }, + { + name: "after timeout2", + advance: timeout2 + 1, + wantCounter1: 1, + wantCounter2: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + counter1 := 0 + counter2 := 0 + clock.AfterFunc(timeout1, func() { + counter1++ + }) + clock.AfterFunc(timeout2, func() { + counter2++ + }) + start := clock.NowMonotonic() + clock.Advance(test.advance) + if got, want := counter1, test.wantCounter1; got != want { + t.Errorf("got counter1 = %d, want = %d", got, want) + } + if got, want := counter2, test.wantCounter2; got != want { + t.Errorf("got counter2 = %d, want = %d", got, want) + } + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + t.Errorf("got elapsed = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index be03fb086..504408878 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -31,6 +31,27 @@ const ( // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 + // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an + // errant packet's transport layer that an ICMP error type packet should + // attempt to send as per RFC 792 (see each type) and RFC 1122 + // section 3.2.2 which states: + // Every ICMP error message includes the Internet header and at + // least the first 8 data octets of the datagram that triggered + // the error; more than 8 octets MAY be sent; this header and data + // MUST be unchanged from the received datagram. + // + // RFC 792 shows: + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Code | Checksum | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | unused | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Internet Header + 64 bits of Original Data Datagram | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ICMPv4MinimumErrorPayloadSize = 8 + // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 @@ -39,15 +60,19 @@ const ( icmpv4ChecksumOffset = 2 // icmpv4MTUOffset is the offset of the MTU field - // in a ICMPv4FragmentationNeeded message. + // in an ICMPv4FragmentationNeeded message. icmpv4MTUOffset = 6 // icmpv4IdentOffset is the offset of the ident field - // in a ICMPv4EchoRequest/Reply message. + // in an ICMPv4EchoRequest/Reply message. icmpv4IdentOffset = 4 + // icmpv4PointerOffset is the offset of the pointer field + // in an ICMPv4ParamProblem message. + icmpv4PointerOffset = 4 + // icmpv4SequenceOffset is the offset of the sequence field - // in a ICMPv4EchoRequest/Reply message. + // in an ICMPv4EchoRequest/Reply message. icmpv4SequenceOffset = 6 ) @@ -72,15 +97,23 @@ const ( ICMPv4InfoReply ICMPv4Type = 16 ) +// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. +const ( + ICMPv4TTLExceeded ICMPv4Code = 0 +) + // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4NetUnreachable ICMPv4Code = 0 ICMPv4HostUnreachable ICMPv4Code = 1 ICMPv4ProtoUnreachable ICMPv4Code = 2 ICMPv4PortUnreachable ICMPv4Code = 3 ICMPv4FragmentationNeeded ICMPv4Code = 4 ) +// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed. +const ICMPv4UnusedCode ICMPv4Code = 0 + // Type is the ICMP type field. func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 20b01d8f4..6be31beeb 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -54,9 +54,17 @@ const ( // address. ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize - // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. + // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet. ICMPv6EchoMinimumSize = 8 + // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header, + // as per RFC 4443, Apendix A, item 4 and the errata. + // ... all ICMP error messages shall have exactly + // 32 bits of type-specific data, so that receivers can reliably find + // the embedded invoking packet even when they don't recognize the + // ICMP message Type. + ICMPv6ErrorHeaderSize = 8 + // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize @@ -69,6 +77,10 @@ const ( // in an ICMPv6 message. icmpv6ChecksumOffset = 2 + // icmpv6PointerOffset is the offset of the pointer + // in an ICMPv6 Parameter problem message. + icmpv6PointerOffset = 4 + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 // PacketTooBig message. icmpv6MTUOffset = 4 @@ -89,9 +101,10 @@ const ( NDPHopLimit = 255 ) -// ICMPv6Type is the ICMP type field described in RFC 4443 and friends. +// ICMPv6Type is the ICMP type field described in RFC 4443. type ICMPv6Type byte +// Values for use in the Type field of ICMPv6 packet from RFC 4433. const ( ICMPv6DstUnreachable ICMPv6Type = 1 ICMPv6PacketTooBig ICMPv6Type = 2 @@ -109,7 +122,18 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// ICMPv6Code is the ICMP code field described in RFC 4443. +// IsErrorType returns true if the receiver is an ICMP error type. +func (typ ICMPv6Type) IsErrorType() bool { + // Per RFC 4443 section 2.1: + // ICMPv6 messages are grouped into two classes: error messages and + // informational messages. Error messages are identified as such by a + // zero in the high-order bit of their message Type field values. Thus, + // error messages have message types from 0 to 127; informational + // messages have message types from 128 to 255. + return typ&0x80 == 0 +} + +// ICMPv6Code is the ICMP Code field described in RFC 4443. type ICMPv6Code byte // ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443 @@ -132,9 +156,14 @@ const ( // ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. const ( + // ICMPv6ErroneousHeader indicates an erroneous header field was encountered. ICMPv6ErroneousHeader ICMPv6Code = 0 - ICMPv6UnknownHeader ICMPv6Code = 1 - ICMPv6UnknownOption ICMPv6Code = 2 + + // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered. + ICMPv6UnknownHeader ICMPv6Code = 1 + + // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered. + ICMPv6UnknownOption ICMPv6Code = 2 ) // ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use @@ -153,6 +182,16 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } +// TypeSpecific returns the type specific data field. +func (b ICMPv6) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +// SetTypeSpecific sets the type specific data field. +func (b ICMPv6) SetTypeSpecific(val uint32) { + binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) +} + // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 680eafd16..b07d9991d 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -80,7 +80,8 @@ type IPv4Fields struct { type IPv4 []byte const ( - // IPv4MinimumSize is the minimum size of a valid IPv4 packet. + // IPv4MinimumSize is the minimum size of a valid IPv4 packet; + // i.e. a packet header with no options. IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given @@ -88,6 +89,16 @@ const ( // units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. + // + // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 + // header size). But RFC 791 section 3.2 discusses the design of the IPv4 + // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of + // 65,536 octets. Note that this is consistent with the the datagram total + // length field (of course, the header is counted in the total length and not + // in the fragments)." + IPv4MaximumPayloadSize = 65536 + // MinIPFragmentPayloadSize is the minimum number of payload bytes that // the first fragment must carry when an IPv4 packet is fragmented. MinIPFragmentPayloadSize = 8 @@ -317,7 +328,7 @@ func IsV4MulticastAddress(addr tcpip.Address) bool { } // IsV4LoopbackAddress determines if the provided address is an IPv4 loopback -// address (belongs to 127.0.0.1/8 subnet). +// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3. func IsV4LoopbackAddress(addr tcpip.Address) bool { if len(addr) != IPv4AddressSize { return false diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index ea3823898..ef454b313 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -34,6 +34,9 @@ const ( hopLimit = 7 v6SrcAddr = 8 v6DstAddr = v6SrcAddr + IPv6AddressSize + + // IPv6FixedHeaderSize is the size of the fixed header. + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -69,11 +72,15 @@ type IPv6 []byte const ( // IPv6MinimumSize is the minimum size of a valid IPv6 packet. - IPv6MinimumSize = 40 + IPv6MinimumSize = IPv6FixedHeaderSize // IPv6AddressSize is the size, in bytes, of an IPv6 address. IPv6AddressSize = 16 + // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per + // RFC 8200 Section 4.5. + IPv6MaximumPayloadSize = 65535 + // IPv6ProtocolNumber is IPv6's network protocol number. IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 3499d8399..583c2c5d3 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { // obtained before modification is no longer used. type IPv6OptionsExtHdrOptionsIterator struct { reader bytes.Reader + + // optionOffset is the number of bytes from the first byte of the + // options field to the beginning of the current option. + optionOffset uint32 + + // nextOptionOffset is the offset of the next option. + nextOptionOffset uint32 +} + +// OptionOffset returns the number of bytes parsed while processing the +// option field of the current Extension Header. +func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { + return i.optionOffset } // IPv6OptionUnknownAction is the action that must be taken if the processing @@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} // the options data, or an error occured. func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { for { + i.optionOffset = i.nextOptionOffset temp, err := i.reader.ReadByte() if err != nil { // If we can't read the first byte of a new option, then we know the @@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // know the option does not have Length and Data fields. End processing of // the Pad1 option and continue processing the buffer as a new option. if id == ipv6Pad1ExtHdrOptionIdentifier { + i.nextOptionOffset = i.optionOffset + 1 continue } @@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) } - // Special-case the variable length padding option to avoid a copy. - if id == ipv6PadNExtHdrOptionIdentifier { - // Do we have enough bytes in the reader for the PadN option? - if n := i.reader.Len(); n < int(length) { - // Reset the reader to effectively consume the remaining buffer. - i.reader.Reset(nil) - - // We return the same error as if we failed to read a non-padding option - // so consumers of this iterator don't need to differentiate between - // padding and non-padding options. - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) - } + // Do we have enough bytes in the reader for the next option? + if n := i.reader.Len(); n < int(length) { + // Reset the reader to effectively consume the remaining buffer. + i.reader.Reset(nil) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ + switch id { + case ipv6PadNExtHdrOptionIdentifier: + // Special-case the variable length padding option to avoid a copy. if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil { panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } - - // End processing of the PadN option and continue processing the buffer as - // a new option. continue - } - - bytes := make([]byte, length) - if n, err := io.ReadFull(&i.reader, bytes); err != nil { - // io.ReadFull may return io.EOF if i.reader has been exhausted. We use - // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the - // Length field found in the option. - if err == io.EOF { - err = io.ErrUnexpectedEOF + default: + bytes := make([]byte, length) + if n, err := io.ReadFull(&i.reader, bytes); err != nil { + // io.ReadFull may return io.EOF if i.reader has been exhausted. We use + // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the + // Length field found in the option. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) } - - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } - - return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } } @@ -382,6 +396,29 @@ type IPv6PayloadIterator struct { // Indicates to the iterator that it should return the remaining payload as a // raw payload on the next call to Next. forceRaw bool + + // headerOffset is the offset of the beginning of the current extension + // header starting from the beginning of the fixed header. + headerOffset uint32 + + // parseOffset is the byte offset into the current extension header of the + // field we are currently examining. It can be added to the header offset + // if the absolute offset within the packet is required. + parseOffset uint32 + + // nextOffset is the offset of the next header. + nextOffset uint32 +} + +// HeaderOffset returns the offset to the start of the extension +// header most recently processed. +func (i IPv6PayloadIterator) HeaderOffset() uint32 { + return i.headerOffset +} + +// ParseOffset returns the number of bytes successfully parsed. +func (i IPv6PayloadIterator) ParseOffset() uint32 { + return i.headerOffset + i.parseOffset } // MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing @@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa nextHdrIdentifier: nextHdrIdentifier, payload: payload.Clone(nil), // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. - reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + nextOffset: IPv6FixedHeaderSize, } } @@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Next is unable to return anything because the iterator has reached the end of // the payload, or an error occured. func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + i.headerOffset = i.nextOffset + i.parseOffset = 0 // We could be forced to return i as a raw header when the previous header was // a fragment extension header as the data following the fragment extension // header may not be complete. @@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { return IPv6RoutingExtHdr(bytes), false, nil case IPv6FragmentExtHdrIdentifier: var data [6]byte - // We ignore the returned bytes becauase we know the fragment extension + // We ignore the returned bytes because we know the fragment extension // header specific data will fit in data. nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) if err != nil { @@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP if err != nil { return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) } + i.parseOffset++ var length uint8 length, err = i.reader.ReadByte() i.payload.TrimFront(1) + if err != nil { if fragmentHdr { return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) @@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP length = 0 } + // Make parseOffset point to the first byte of the Extension Header + // specific data. + i.parseOffset++ + + // length is in 8 byte chunks but doesn't include the first one. + // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement + // in section 4.8 for new extension headers at the top of page 24. + // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet + // units, not including the first 8 octets. + i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded if bytes == nil { bytes = make([]byte, bytesLen) diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD new file mode 100644 index 000000000..2adee9288 --- /dev/null +++ b/pkg/tcpip/header/parse/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "parse", + srcs = ["parse.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go new file mode 100644 index 000000000..5ca75c834 --- /dev/null +++ b/pkg/tcpip/header/parse/parse.go @@ -0,0 +1,168 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package parse provides utilities to parse packets. +package parse + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// ARP populates pkt's network header with an ARP header found in +// pkt.Data. +// +// Returns true if the header was successfully parsed. +func ARP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.NetworkHeader().Consume(header.ARPSize) + if ok { + pkt.NetworkProtocolNumber = header.ARPProtocolNumber + } + return ok +} + +// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network +// header with the IPv4 header. +// +// Returns true if the header was successfully parsed. +func IPv4(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return false + } + ipHdr := header.IPv4(hdr) + + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return false + } + ipHdr = header.IPv4(hdr) + + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return true +} + +// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network +// header with the IPv6 header. +func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, 0, 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + var nextHdr tcpip.TransportProtocolNumber + var extensionsSize int + +traverseExtensions: + for { + extHdr, done, err := it.Next() + if err != nil { + break + } + + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + if fragID == 0 && fragOffset == 0 && !fragMore { + fragID = extHdr.ID() + fragOffset = extHdr.FragmentOffset() + fragMore = extHdr.More() + } + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + return nextHdr, fragID, fragOffset, fragMore, true +} + +// UDP parses a UDP packet found in pkt.Data and populates pkt's transport +// header with the UDP header. +// +// Returns true if the header was successfully parsed. +func UDP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + return ok +} + +// TCP parses a TCP packet found in pkt.Data and populates pkt's transport +// header with the TCP header. +// +// Returns true if the header was successfully parsed. +func TCP(pkt *stack.PacketBuffer) bool { + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset + } + + _, ok = pkt.TransportHeader().Consume(hdrLen) + pkt.TransportProtocolNumber = header.TCPProtocolNumber + return ok +} diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 9339d637f..98bdd29db 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "math" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -55,6 +56,10 @@ const ( // UDPMinimumSize is the minimum size of a valid UDP packet. UDPMinimumSize = 8 + // UDPMaximumSize is the maximum size of a valid UDP packet. The length field + // in the UDP header is 16 bits as per RFC 768. + UDPMaximumSize = math.MaxUint16 + // UDPProtocolNumber is UDP's transport protocol number. UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..6c410c5a6 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,14 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + srcs = [ + "errors_test.go", + ], + library = "rawfile", + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index a0a873c84..604868fd8 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error // *tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. +// tcpip.ErrInvalidEndpointState (EINVAL). func TranslateErrno(e syscall.Errno) *tcpip.Error { - if err := translations[e]; err != nil { - return err + if e > 0 && e < syscall.Errno(len(translations)) { + if err := translations[e]; err != nil { + return err + } } return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go new file mode 100644 index 000000000..e4cdc66bd --- /dev/null +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestTranslateErrno(t *testing.T) { + for _, test := range []struct { + errno syscall.Errno + translated *tcpip.Error + }{ + { + errno: syscall.Errno(0), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(maxErrno), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(514), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.EEXIST, + translated: tcpip.ErrDuplicateAddress, + }, + } { + got := TranslateErrno(test.errno) + if got != test.translated { + t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + } + } +} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 7cbc305e7..4aac12a8c 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 4fb127978..560477926 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var transProto uint8 src := tcpip.Address("unknown") dst := tcpip.Address("unknown") - id := 0 - size := uint16(0) + var size uint16 + var id uint32 var fragmentOffset uint16 var moreFragments bool - // Examine the packet using a new VV. Backing storage must not be written. - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - + // Clone the packet buffer to not modify the original. + // + // We don't clone the original packet buffer so that the new packet buffer + // does not have any of its headers set. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return } - ipv4 := header.IPv4(hdr) + + ipv4 := header.IPv4(pkt.NetworkHeader().View()) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) - id = int(ipv4.ID()) + id = uint32(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) + proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return } - ipv6 := header.IPv6(hdr) + + ipv6 := header.IPv6(pkt.NetworkHeader().View()) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() - transProto = ipv6.NextHeader() + transProto = uint8(proto) size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + id = fragID + moreFragments = fragMore + fragmentOffset = fragOffset case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { + if parse.ARP(pkt) { return } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + + arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( "%s arp %s (%s) -> %s (%s) valid:%t", prefix, @@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { + if ok := parse.UDP(pkt); !ok { break } - udp := header.UDP(hdr) + + udp := header.UDP(pkt.TransportHeader().View()) if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { + if ok := parse.TCP(pkt); !ok { break } - tcp := header.TCP(hdr) + + tcp := header.TCP(pkt.TransportHeader().View()) if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > vv.Size() && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) + if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 46083925c..59710352b 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -9,6 +9,7 @@ go_test( "ip_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -17,6 +18,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", ], diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 82c073e32..b40dde96b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7aaee08c4..b47a7be51 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -15,20 +15,15 @@ // Package arp implements the ARP network protocol. It is used to resolve // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. -// -// To use it in the networking stack, pass arp.NewProtocol() as one of the -// network protocols when calling stack.New. Then add an "arp" address to every -// NIC on the stack that should respond to ARP requests. That is: -// -// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { -// // handle err -// } package arp import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -40,15 +35,57 @@ const ( ProtocolAddress = tcpip.Address("arp") ) -// endpoint implements stack.NetworkEndpoint. +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - protocol *protocol - nicID tcpip.NICID + stack.AddressableEndpointState + + protocol *protocol + + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler } +func (e *endpoint) Enable() *tcpip.Error { + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + e.setEnabled(true) + return nil +} + +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +func (e *endpoint) setEnabled(v bool) { + if v { + atomic.StoreUint32(&e.enabled, 1) + } else { + atomic.StoreUint32(&e.enabled, 0) + } +} + +func (e *endpoint) Disable() { + e.setEnabled(false) +} + // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. func (e *endpoint) DefaultTTL() uint8 { return 0 @@ -59,19 +96,13 @@ func (e *endpoint) MTU() uint32 { return lmtu - uint32(e.MaxHeaderLength()) } -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID -} - -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported @@ -92,6 +123,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return @@ -102,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { localAddr := tcpip.Address(h.ProtocolAddressTarget()) if e.nud == nil { - if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } @@ -136,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) return } @@ -168,14 +203,16 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ protocol: p, - nicID: nicID, - linkEP: sender, + nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } + e.AddressableEndpointState.Init(e) + return e } // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. @@ -234,14 +271,14 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - _, ok = pkt.NetworkHeader().Consume(header.ARPSize) - if !ok { - return 0, false, false - } - return 0, false, true + return 0, false, parse.ARP(pkt) } // NewProtocol returns an ARP network protocol. -func NewProtocol() stack.NetworkProtocol { +// +// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" +// address must be added to every NIC that should respond to ARP requests. See +// ProtocolAddress for more details. +func NewProtocol(*stack.Stack) stack.NetworkProtocol { return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 9c9a859e3..626af975a 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -176,8 +176,8 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, NUDConfigs: c, NUDDisp: &d, UseNeighborCache: useNeighborCache, @@ -442,7 +442,7 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := arp.NewProtocol() + p := arp.NewProtocol(nil) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 96c5f42f8..e247f06a4 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -43,5 +43,6 @@ go_test( library = ":fragmentation", deps = [ "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6a4843f92..e1909fab0 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -81,6 +81,8 @@ type Fragmentation struct { size int timeout time.Duration blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job } // NewFragmentation creates a new Fragmentation. @@ -97,7 +99,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -110,13 +112,17 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea blockSize = minBlockSize } - return &Fragmentation{ + f := &Fragmentation{ reassemblers: make(map[FragmentID]*reassembler), highLimit: highMemoryLimit, lowLimit: lowMemoryLimit, timeout: reassemblingTimeout, blockSize: blockSize, + clock: clock, } + f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) + + return f } // Process processes an incoming fragment belonging to an ID and returns a @@ -155,15 +161,17 @@ func (f *Fragmentation) Process( f.mu.Lock() r, ok := f.reassemblers[id] - if ok && r.tooOld(f.timeout) { - // This is very likely to be an id-collision or someone performing a slow-rate attack. - f.release(r) - ok = false - } if !ok { - r = newReassembler(id) + r = newReassembler(id, f.clock) f.reassemblers[id] = r + wasEmpty := f.rList.Empty() f.rList.PushFront(r) + if wasEmpty { + // If we have just pushed a first reassembler into an empty list, we + // should kickstart the release job. The release job will keep + // rescheduling itself until the list becomes empty. + f.releaseReassemblersLocked() + } } f.mu.Unlock() @@ -211,3 +219,27 @@ func (f *Fragmentation) release(r *reassembler) { f.size = 0 } } + +// releaseReassemblersLocked releases already-expired reassemblers, then +// schedules the job to call back itself for the remaining reassemblers if +// any. This function must be called with f.mu locked. +func (f *Fragmentation) releaseReassemblersLocked() { + now := f.clock.NowMonotonic() + for { + // The reassembler at the end of the list is the oldest. + r := f.rList.Back() + if r == nil { + // The list is empty. + break + } + elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + if f.timeout > elapsed { + // If the oldest reassembler has not expired, schedule the release + // job so that this function is called back when it has expired. + f.releaseJob.Schedule(f.timeout - elapsed) + break + } + // If the oldest reassembler has already expired, release it. + f.release(r) + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 416604659..189b223c5 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -21,6 +21,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) // vv is a helper to build VectorisedView from different strings. @@ -95,7 +96,7 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) @@ -131,25 +132,126 @@ func TestFragmentationProcess(t *testing.T) { } func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(minBlockSize, 1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1")) - if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err) + const ( + reassemblyTimeout = time.Millisecond + protocol = 0xff + ) + + type fragment struct { + first uint16 + last uint16 + more bool + data string } - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") + + type event struct { + // name is a nickname of this event. + name string + + // clockAdvance is a duration to advance the clock. The clock advances + // before a fragment specified in the fragment field is processed. + clockAdvance time.Duration + + // fragment is a fragment to process. This can be nil if there is no + // fragment to process. + fragment *fragment + + // expectDone is true if the fragmentation instance should report the + // reassembly is done after the fragment is processd. + expectDone bool + + // sizeAfterEvent is the expected size of the fragmentation instance after + // the event. + sizeAfterEvent int + } + + half1 := &fragment{first: 0, last: 0, more: true, data: "0"} + half2 := &fragment{first: 1, last: 1, more: false, data: "1"} + + tests := []struct { + name string + events []event + }{ + { + name: "half1 and half2 are reassembled successfully", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2", + fragment: half2, + expectDone: true, + sizeAfterEvent: 0, + }, + }, + }, + { + name: "half1 timeout, half2 timeout", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half1 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + { + name: "half2", + fragment: half2, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half2 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + for _, event := range test.events { + clock.Advance(event.clockAdvance) + if frag := event.fragment; frag != nil { + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + if err != nil { + t.Fatalf("%s: f.Process failed: %s", event.name, err) + } + if done != event.expectDone { + t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) + } + } + if got, want := f.size, event.sizeAfterEvent; got != want { + t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + } + } + }) } } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) // Send first fragment with id = 1. @@ -173,7 +275,7 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Send the same packet again. @@ -268,7 +370,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{}) _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index f044867dc..9bb051a30 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -40,15 +40,15 @@ type reassembler struct { deleted int heap fragHeap done bool - creationTime time.Time + creationTime int64 } -func newReassembler(id FragmentID) *reassembler { +func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, holes: make([]hole, 0, 16), heap: make(fragHeap, 0, 8), - creationTime: time.Now(), + creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, @@ -116,10 +116,6 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf return res, r.proto, true, consumed, nil } -func (r *reassembler) tooOld(timeout time.Duration) bool { - return time.Now().Sub(r.creationTime) > timeout -} - func (r *reassembler) checkDoneOrMark() bool { r.mu.Lock() prev := r.done diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index dff7c9dcb..a0a04a027 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -18,6 +18,8 @@ import ( "math" "reflect" "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) type updateHolesInput struct { @@ -94,7 +96,7 @@ var holesTestCases = []struct { func TestUpdateHoles(t *testing.T) { for _, c := range holesTestCases { - r := newReassembler(FragmentID{}) + r := newReassembler(FragmentID{}, &faketime.NullClock{}) for _, i := range c.in { r.updateHoles(i.first, i.last, i.more) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index e45dd17f8..6861cfdaf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -17,6 +17,7 @@ package ip_test import ( "testing" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -25,26 +26,35 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( - localIpv4Addr = "\x0a\x00\x00\x01" - localIpv4PrefixLen = 24 - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - localIpv6PrefixLen = 120 - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - nicID = 1 + localIPv4Addr = "\x0a\x00\x00\x01" + remoteIPv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + nicID = 1 ) +var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv4Addr, + PrefixLen: 24, +} + +var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv6Addr, + PrefixLen: 120, +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. The latter is used to @@ -98,9 +108,10 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ + return stack.TransportPacketHandled } // DeliverTransportControlPacket is called by network endpoints after parsing @@ -194,8 +205,8 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) s.AddAddress(nicID, ipv4.ProtocolNumber, local) @@ -210,8 +221,8 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) s.AddAddress(nicID, ipv6.ProtocolNumber, local) @@ -224,33 +235,294 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStack(t *testing.T) *stack.Stack { +func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) e := channel.New(0, 1280, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err) + v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err) + v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) } + return s, e +} + +func buildDummyStack(t *testing.T) *stack.Stack { + t.Helper() + + s, _ := buildDummyStackWithLinkEndpoint(t) return s } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + tester testObject + + mu struct { + sync.RWMutex + disabled bool + } +} + +func (*testInterface) ID() tcpip.NICID { + return nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (t *testInterface) Enabled() bool { + t.mu.RLock() + defer t.mu.RUnlock() + return !t.mu.disabled +} + +func (t *testInterface) setEnabled(v bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.mu.disabled = !v +} + +func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { + return &t.tester +} + +func TestSourceAddressValidation(t *testing.T) { + rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + srcAddress tcpip.Address + rxICMP func(*channel.Endpoint, tcpip.Address) + valid bool + }{ + { + name: "IPv4 valid", + srcAddress: "\x01\x02\x03\x04", + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 valid", + srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 multicast", + srcAddress: "\xe0\x00\x00\x01", + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv6 multicast", + srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + rxICMP: rxIPv6ICMP, + valid: false, + }, + { + name: "IPv4 broadcast", + srcAddress: header.IPv4Broadcast, + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv4 subnet broadcast", + srcAddress: func() tcpip.Address { + subnet := localIPv4AddrWithPrefix.Subnet() + return subnet.Broadcast() + }(), + rxICMP: rxIPv4ICMP, + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t) + test.rxICMP(e, test.srcAddress) + + var wantValid uint64 + if test.valid { + wantValid = 1 + } + + if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { + t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { + t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) + } + }) + } +} + +func TestEnableWhenNICDisabled(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protocolFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protocolFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var nic testInterface + nic.setEnabled(false) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, + }) + p := s.NetworkProtocolInstance(test.protoNum) + + // We pass nil for all parameters except the NetworkInterface and Stack + // since Enable only depends on these. + ep := p.NewEndpoint(&nic, nil, nil, nil) + + // The endpoint should initially be disabled, regardless the NIC's enabled + // status. + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Attempting to enable the endpoint while the NIC is disabled should + // fail. + nic.setEnabled(false) + if err := ep.Enable(); err != tcpip.ErrNotPermitted { + t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + } + // ep should consider the NIC's enabled status when determining its own + // enabled status so we "enable" the NIC to read just the endpoint's + // enabled status. + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Enabling the interface after the NIC has been enabled should succeed. + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + if !ep.Enabled() { + t.Fatal("got ep.Enabled() = false, want = true") + } + + // ep should consider the NIC's enabled status when determining its own + // enabled status. + nic.setEnabled(false) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Disabling the endpoint when the NIC is enabled should make the endpoint + // disabled. + nic.setEnabled(true) + ep.Disable() + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + }) + } +} + func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() // Allocate and initialize the payload view. @@ -266,12 +538,12 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIPv4Addr + nic.tester.dstAddr = remoteIPv4Addr + nic.tester.contents = payload - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -285,11 +557,21 @@ func TestIPv4Send(t *testing.T) { } func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + totalLen := header.IPv4MinimumSize + 30 view := buffer.NewView(totalLen) ip := header.IPv4(view) @@ -298,8 +580,8 @@ func TestIPv4Receive(t *testing.T) { TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. @@ -308,12 +590,12 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = view[header.IPv4MinimumSize:totalLen] - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -324,8 +606,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -349,17 +631,26 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") + r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") if err != nil { t.Fatal(err) } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) @@ -371,7 +662,7 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: uint8(header.ICMPv4ProtocolNumber), SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, + DstAddr: localIPv4Addr, }) // Create the ICMP header. @@ -389,8 +680,8 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, + SrcAddr: localIPv4Addr, + DstAddr: remoteIPv4Addr, }) // Make payload be non-zero. @@ -400,27 +691,37 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } } func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + totalLen := header.IPv4MinimumSize + 24 frag1 := buffer.NewView(totalLen) @@ -432,8 +733,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { Protocol: 10, FragmentOffset: 0, Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -448,8 +749,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -457,12 +758,12 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -475,8 +776,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + if nic.tester.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) } // Send second segment. @@ -487,17 +788,26 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } func TestIPv6Send(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + // Allocate and initialize the payload view. payload := buffer.NewView(100) for i := 0; i < len(payload); i++ { @@ -511,12 +821,12 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIPv6Addr + nic.tester.dstAddr = remoteIPv6Addr + nic.tester.contents = payload - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -530,11 +840,20 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + totalLen := header.IPv6MinimumSize + 30 view := buffer.NewView(totalLen) ip := header.IPv6(view) @@ -542,8 +861,8 @@ func TestIPv6Receive(t *testing.T) { PayloadLength: uint16(totalLen - header.IPv6MinimumSize), NextHeader: 10, HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -552,12 +871,12 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv6Addr + nic.tester.dstAddr = localIPv6Addr + nic.tester.contents = view[header.IPv6MinimumSize:totalLen] - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -569,8 +888,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -601,7 +920,7 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } r, err := buildIPv6Route( - localIpv6Addr, + localIPv6Addr, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", ) if err != nil { @@ -609,11 +928,20 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize @@ -627,7 +955,7 @@ func TestIPv6ReceiveControl(t *testing.T) { NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, SrcAddr: outerSrcAddr, - DstAddr: localIpv6Addr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -643,8 +971,8 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: 100, NextHeader: 10, HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, }) // Build the fragmentation header if needed. @@ -666,19 +994,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv6Addr + nic.tester.dstAddr = localIPv6Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index d142b4ffa..ee2c23e91 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,9 +10,11 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", @@ -26,11 +28,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index b5659a36b..a8985ff5d 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,9 @@ package ipv4 import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -39,7 +42,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. src := hdr.SourceAddress() - if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -105,11 +108,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -131,7 +134,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: dataVV, }) - + // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header + // information we will have to change this code to handle the ICMP header + // no longer being in the data buffer. + replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber // Send out the reply packet. sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ @@ -193,3 +199,177 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Invalid.Increment() } } + +// ======= ICMP Error packet generation ========= + +// icmpReason is a marker interface for IPv4 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// icmpReasonProtoUnreachable is an error where the transport protocol is +// not supported. +type icmpReasonProtoUnreachable struct{} + +func (*icmpReasonProtoUnreachable) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv4 and sends it back to the remote device that sent +// the problematic packet. It incorporates as much of that packet as +// possible as well as any error metadata as is available. returnError +// expects pkt to hold a valid IPv4 packet as per the wire format. +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + sent := r.Stats().ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + // We check we are responding only when we are allowed to. + // See RFC 1812 section 4.3.2.7 (shown below). + // + // ========= + // 4.3.2.7 When Not to Send ICMP Errors + // + // An ICMP error message MUST NOT be sent as the result of receiving: + // + // o An ICMP error message, or + // + // o A packet which fails the IP header validation tests described in + // Section [5.2.2] (except where that section specifically permits + // the sending of an ICMP error message), or + // + // o A packet destined to an IP broadcast or IP multicast address, or + // + // o A packet sent as a Link Layer broadcast or multicast, or + // + // o Any fragment of a datagram other then the first fragment (i.e., a + // packet for which the fragment offset in the IP header is nonzero). + // + // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in + // response to a non-initial fragment, but it currently can not happen. + + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + return nil + } + + networkHeader := pkt.NetworkHeader().View() + transportHeader := pkt.TransportHeader().View() + + // Don't respond to icmp error packets. + if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + // TODO(gvisor.dev/issue/3810): + // Unfortunately the current stack pretty much always has ICMPv4 headers + // in the Data section of the packet but there is no guarantee that is the + // case. If this is the case grab the header to make it like all other + // packet types. When this is cleaned up the Consume should be removed. + if transportHeader.IsEmpty() { + var ok bool + transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) + if !ok { + return nil + } + } else if transportHeader.Size() < header.ICMPv4MinimumSize { + return nil + } + // We need to decide to explicitly name the packets we can respond to or + // the ones we can not respond to. The decision is somewhat arbitrary and + // if problems arise this could be reversed. It was judged less of a breach + // of protocol to not respond to unknown non-error packets than to respond + // to unknown error packets so we take the first approach. + switch header.ICMPv4(transportHeader).Type() { + case + header.ICMPv4EchoReply, + header.ICMPv4Echo, + header.ICMPv4Timestamp, + header.ICMPv4TimestampReply, + header.ICMPv4InfoRequest, + header.ICMPv4InfoReply: + default: + // Assume any type we don't know about may be an error type. + return nil + } + } + + // Now work out how much of the triggering packet we should return. + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes. + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 and RFC 792 where it mentioned that at + // least 8 bytes of the payload must be included. Today linux and other + // systems implement the RFC 1812 definition and not the original + // requirement. We treat 8 bytes as the minimum but will try send more. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + + if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + return nil + } + + payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + + // The buffers used by pkt may be used elsewhere in the system. + // For example, an AF_RAW or AF_PACKET socket may use what the transport + // protocol considers an unreachable destination. Thus we deep copy pkt to + // prevent multiple ownership and SR errors. The new copy is a vectorized + // view with the entire incoming IP packet reassembled and truncated as + // required. This is now the payload of the new ICMP packet and no longer + // considered a packet in its own right. + newHeader := append(buffer.View(nil), networkHeader...) + newHeader = append(newHeader, transportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) + payload.CapLength(payloadLen) + + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + switch reason.(type) { + case *icmpReasonPortUnreachable: + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + case *icmpReasonProtoUnreachable: + icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + counter := sent.DstUnreachable + + if err := r.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + icmpPkt, + ); err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index fa4ae2012..1f6e14c3f 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -12,20 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv4 contains the implementation of the ipv4 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv4.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv4 contains the implementation of the ipv4 network protocol. package ipv4 import ( + "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -50,22 +48,115 @@ const ( fragmentblockSize = 8 ) +var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() + +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID + nic stack.NetworkInterface linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack + + // enabled is set to 1 when the enpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + } } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ - nicID: nicID, - linkEP: linkEP, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, + linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, - stack: st, + } + e.mu.addressableEndpointState.Init(e) + return e +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Create an endpoint to receive broadcast packets on this interface. + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + return err + } + // We have no need for the address endpoint. + ep.DecRef() + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) + return err +} + +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() +} + +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) + } + + // The address may have already been removed. + if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } } @@ -80,16 +171,6 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID -} - // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { @@ -197,6 +278,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, // Send out the fragment. if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(n - i)) return err } r.Stats().IP.PacketsSent.Increment() @@ -231,21 +313,24 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() return nil } - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. - // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than - // only NATted packets, but removing this check short circuits broadcasts - // before they are sent out to other hosts. + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) @@ -265,6 +350,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.Stats().IP.PacketsSent.Increment() @@ -285,20 +371,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkt = pkt.Next() } - nicName := e.stack.FindNICNameFromID(e.NICID()) + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - // Slow Path as we are dropping some packets in the batch degrade to + // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -307,7 +397,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -318,12 +408,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err } n++ } r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, nil + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -373,25 +467,42 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } + if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - - return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + // As per RFC 1122 section 3.2.1.3: + // When a host sends any datagram, the IP source address MUST + // be one of its own IP addresses (but not a broadcast or + // multicast address). + if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() return } @@ -404,11 +515,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < h.FragmentOffset() { + start := h.FragmentOffset() + // Drop the fragment if the size of the reassembled payload would exceed the + // maximum payload size. + // + // Note that this addition doesn't overflow even on 32bit architecture + // because pkt.Data.Size() should not exceed 65535 (the max IP datagram + // size). Otherwise the packet would've been rejected as invalid before + // reaching here. + if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return @@ -425,8 +540,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ID: uint32(h.ID()), Protocol: proto, }, - h.FragmentOffset(), - last, + start, + start+uint16(pkt.Data.Size())-1, h.More(), proto, pkt.Data, @@ -440,27 +555,165 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } } + + r.Stats().IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { + // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport + // headers, the setting of the transport number here should be + // unnecessary and removed. + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt) return } - r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, pkt) + + switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + case stack.TransportPacketHandled: + case stack.TransportPacketDestinationPortUnreachable: + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC: 1122 Section 3.2.2.1 + // A host SHOULD generate Destination Unreachable messages with code: + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported + _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt) + default: + panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) + } } // Close cleans up resources associated with the endpoint. -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + e.disableLocked() + e.mu.addressableEndpointState.Cleanup() +} + +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.RemovePermanentAddress(addr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + + loopback := e.nic.IsLoopback() + addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + // IPv4 has a notion of a subnet broadcast address and considers the + // loopback interface bound to an address's whole subnet (on linux). + return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) + }) + if addressEndpoint != nil { + return addressEndpoint + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) + if err != nil { + // AddAddress only returns an error if the address is already assigned, + // but we just checked above if the address exists so we expect no error. + panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) + } + return addressEndpoint +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV4MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) type protocol struct { - ids []uint32 - hashIV uint32 + stack *stack.Stack // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. defaultTTL uint32 + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + + ids []uint32 + hashIV uint32 + fragmentation *fragmentation.Fragmentation } @@ -523,37 +776,28 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return 0, false, false } - ipHdr := header.IPv4(hdr) - // Header may have options, determine the true header length. - headerLen := int(ipHdr.HeaderLength()) - if headerLen < header.IPv4MinimumSize { - // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in - // order for the packet to be valid. Figure out if we want to reject this - // case. - headerLen = header.IPv4MinimumSize - } - hdr, ok = pkt.NetworkHeader().Consume(headerLen) - if !ok { - return 0, false, false - } - ipHdr = header.IPv4(hdr) + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true +} - // If this is a fragment, don't bother parsing the transport header. - parseTransportHeader := true - if ipHdr.More() || ipHdr.FragmentOffset() != 0 { - parseTransportHeader = false - } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) - return ipHdr.TransportProtocol(), parseTransportHeader, true +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + if v { + atomic.StoreUint32(&p.forwarding, 1) + } else { + atomic.StoreUint32(&p.forwarding, 0) + } } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -577,7 +821,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui } // NewProtocol returns an IPv4 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -588,9 +832,10 @@ func NewProtocol() stack.NetworkProtocol { hashIV := r[buckets] return &protocol{ + stack: s, ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 197e3bc51..33cd5a3eb 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,18 +17,21 @@ package ipv4_test import ( "bytes" "encoding/hex" - "fmt" - "math/rand" + "math" + "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -36,8 +39,8 @@ import ( func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) const defaultMTU = 65536 @@ -92,29 +95,244 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeRandPkt generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s +// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and +// checks the response. +func TestIPv4Sanity(t *testing.T) { + const ( + defaultMTU = header.IPv6MinimumMTU + ttl = 255 + nicID = 1 + randomSequence = 123 + randomIdent = 42 + ) + var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + ) + + tests := []struct { + name string + headerLength uint8 + minTotalLength uint16 + transportProtocol uint8 + TTL uint8 + shouldFail bool + expectICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + }{ + { + name: "valid", + headerLength: header.IPv4MinimumSize, + minTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: false, + }, + // The TTL tests check that we are not rejecting an incoming packet + // with a zero or one TTL, which has been a point of confusion in the + // past as RFC 791 says: "If this field contains the value zero, then the + // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies + // for the case of the destination host, stating as follows. + // + // A host MUST NOT send a datagram with a Time-to-Live (TTL) + // value of zero. + // + // A host MUST NOT discard a datagram just because it was + // received with TTL less than 2. + { + name: "zero TTL", + headerLength: header.IPv4MinimumSize, + minTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 0, + shouldFail: false, + }, + { + name: "one TTL", + headerLength: header.IPv4MinimumSize, + minTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + shouldFail: false, + }, + { + name: "bad header length", + headerLength: header.IPv4MinimumSize - 1, + minTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (0)", + headerLength: header.IPv4MinimumSize, + minTotalLength: 0, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip - 1)", + headerLength: header.IPv4MinimumSize, + minTotalLength: uint16(header.IPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip + icmp - 1)", + headerLength: header.IPv4MinimumSize, + minTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad protocol", + headerLength: header.IPv4MinimumSize, + minTotalLength: defaultMTU, + transportProtocol: 99, + TTL: ttl, + shouldFail: true, + expectICMP: true, + ICMPType: header.ICMPv4DstUnreachable, + ICMPCode: header.ICMPv4ProtoUnreachable, + }, } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: hdrLength + extraLength, - Data: buffer.NewVectorisedView(totalLength, views), - }) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + } + + // Default routes for IPv4 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + ipHeaderLength := header.IPv4MinimumSize + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + + // Specify ident/seq to make sure we get the same in the response. + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + if test.minTotalLength < totalLen { + totalLen = test.minTotalLength + } + ip.Encode(&header.IPv4Fields{ + IHL: test.headerLength, + TotalLength: totalLen, + Protocol: test.transportProtocol, + TTL: test.TTL, + SrcAddr: remoteIPv4Addr, + DstAddr: ipv4Addr.Address, + }) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e.Read() + if !ok { + if test.shouldFail { + if test.expectICMP { + t.Fatal("expected ICMP error response missing") + } + return // Expected silent failure. + } + t.Fatal("expected ICMP echo reply missing") + } + + // Check the route that brought the packet to us. + if reply.Route.LocalAddress != ipv4Addr.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) + } + if reply.Route.RemoteAddress != remoteIPv4Addr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) + } + + // Make sure it's all in one buffer. + vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) + replyIPHeader := header.IPv4(vv.ToView()) + + // At this stage we only know it's an IP header so verify that much. + checker.IPv4(t, replyIPHeader, + checker.SrcAddr(ipv4Addr.Address), + checker.DstAddr(remoteIPv4Addr), + ) + + // All expected responses are ICMP packets. + if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { + t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + } + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) + + // Sanity check the response. + switch replyICMPHeader.Type() { + case header.ICMPv4DstUnreachable: + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(ipHeaderLength), + checker.ICMPv4( + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + if !test.shouldFail || !test.expectICMP { + t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + return + case header.ICMPv4EchoReply: + checker.IPv4(t, replyIPHeader, + checker.IPv4HeaderLength(ipHeaderLength), + checker.IPFullLength(uint16(requestPkt.Size())), + checker.ICMPv4( + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Seq(randomSequence), + checker.ICMPv4Ident(randomIdent), + checker.ICMPv4Checksum(), + ), + ) + if test.shouldFail { + t.Fatalf("unexpected Echo Reply packet\n") + } + default: + t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + } + }) } - return pkt } // comparePayloads compared the contents of all the packets against the contents @@ -186,101 +404,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type errorChannel struct { - *channel.Endpoint - Ch chan *stack.PacketBuffer - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan *stack.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { - stack.Route - linkEP *errorChannel -} - -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - return context{ - Route: r, - linkEP: ep, - } -} - func TestFragmentation(t *testing.T) { var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { manyPayloadViewsSizes[i] = 7 } fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + extraHeaderReserveLength int + payloadViewsSizes []int + expectedFrags int }{ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, @@ -295,36 +431,29 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) source := pkt.Clone() - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("err got %v, want %v", err, nil) + t.Errorf("got err = %s, want = nil", err) } - var results []*stack.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } + if got := len(ep.WrittenPackets); got != ft.expectedFrags { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - compareFragments(t, results, source, ft.mtu) + compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } } @@ -335,152 +464,376 @@ func TestFragmentationErrors(t *testing.T) { fragTests := []struct { description string mtu uint32 - hdrLength int + transportHeaderLength int payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error + err *tcpip.Error + allowPackets int + fragmentCount int }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + { + description: "NoFrag", + mtu: 2000, + transportHeaderLength: 0, + payloadViewsSizes: []int{1000}, + err: tcpip.ErrAborted, + allowPackets: 0, + fragmentCount: 1, + }, + { + description: "ErrorOnFirstFrag", + mtu: 500, + transportHeaderLength: 0, + payloadViewsSizes: []int{1000}, + err: tcpip.ErrAborted, + allowPackets: 0, + fragmentCount: 3, + }, + { + description: "ErrorOnSecondFrag", + mtu: 500, + transportHeaderLength: 0, + payloadViewsSizes: []int{1000}, + err: tcpip.ErrAborted, + allowPackets: 1, + fragmentCount: 3, + }, + { + description: "ErrorOnFirstFragMTUSmallerThanHeader", + mtu: 500, + transportHeaderLength: 1000, + payloadViewsSizes: []int{500}, + err: tcpip.ErrAborted, + allowPackets: 0, + fragmentCount: 4, + }, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + r := buildRoute(t, ep) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, }, pkt) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } + if err != ft.err { + t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) } }) } } func TestInvalidFragments(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 6 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + autoChecksum bool // if true, the Checksum field will be overwritten. + } + // These packets have both IHL and TotalLength set to 0. - testCases := []struct { + tests := []struct { name string - packets [][]byte + fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 }{ { - "ihl_totallen_zero_valid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - "ihl_totallen_zero_invalid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset non-zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, + FragmentOffset: 59776, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, }, - 1, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, }, { - // Total Length of 37(20 bytes IP header + 17 bytes of - // payload) - // Frag Offset of 0x1ffe = 8190*8 = 65520 - // Leading to the fragment end to be past 65535. - "ihl_totallen_valid_invalid_frag_offset_1", - [][]byte{ - {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, }, - // The following 3 tests were found by running a fuzzer and were - // triggering a panic in the IPv4 reassembler code. { - "ihl_less_than_ipv4_minimum_size_1", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + // Payload 17 octets and Fragment offset 65520 + // Leading to the fragment end to be past 65536. + name: "fragment ends past 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 17, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(17), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - "ihl_less_than_ipv4_minimum_size_2", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + // Payload 16 octets and fragment offset 65520 + // Leading to the fragment end to be exactly 65536. + name: "fragment ends exactly at 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(16), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 0, + wantMalformedFragments: 0, }, { - "ihl_less_than_ipv4_minimum_size_3", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL less than IPv4 minimum size", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 1944, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize - 12, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 2, + wantMalformedFragments: 0, }, { - "fragment_with_short_total_len_extra_payload", - [][]byte{ - {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "fragment with short TotalLength and extra payload", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 28816, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 4, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - "multiple_fragments_with_more_fragments_set_to_false", - [][]byte{ - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + name: "multiple fragments with More Fragments flag set to false", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 128, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - const nicID tcpip.NICID = 42 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, }, }) + e := channel.New(0, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + len(f.payload) + hdr := buffer.NewPrependable(pktSize) - var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) - var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) - ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicID, sniffer.New(ep)) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + copy(ip[header.IPv4MinimumSize:], f.payload) + + if f.autoChecksum { + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + } - for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + vv := hdr.View().ToVectorisedView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, })) } - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) } }) @@ -534,6 +887,9 @@ func TestReceiveFragments(t *testing.T) { // the fragment block size of 8 (RFC 791 section 3.1 page 14). ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] + // Used to test the max reassembled payload length (65,535 octets). + ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) + udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] type fragmentData struct { srcAddr tcpip.Address @@ -827,14 +1183,36 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: nil, }, + { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Setup a stack and endpoint. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -906,3 +1284,189 @@ func TestReceiveFragments(t *testing.T) { }) } } + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + // Parameterize the tests to run with both WritePacket and WritePackets. + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + } + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatalf("NewSubnet(_, _) failed: %v", err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) + } + return rt +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bcc64994e..97adbcbd4 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -5,14 +5,18 @@ package(licenses = ["notice"]) go_library( name = "ipv6", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "ndp.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/stack", ], @@ -34,6 +38,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go index d199ded6a..09ba133b1 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go @@ -14,7 +14,7 @@ // Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. -package stack +package ipv6 import "strconv" diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2b83c421e..8e9def6b8 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,6 +15,8 @@ package ipv6 import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -39,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match an address we own. src := hdr.SourceAddress() - if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -207,14 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - s := r.Stack() - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now, drop this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is // attempting to perform address resolution on the target. In this case, @@ -227,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // stack know so it can handle such a scenario and do nothing further with // the NS. if r.RemoteAddress == header.IPv6Any { - s.DupTentativeAddrDetected(e.nicID, targetAddr) + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } } // Do not handle neighbor solicitations targeted to an address that is @@ -240,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -275,7 +283,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } else if e.nud != nil { e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } // ICMPv6 Neighbor Solicit messages are always sent to @@ -318,6 +326,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), }) packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -352,20 +361,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // NDP datagrams are very small and ToView() will not incur allocations. na := header.NDPNeighborAdvert(payload.ToView()) targetAddr := na.TargetAddress() - s := r.Stack() - - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now short-circuit this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - s.DupTentativeAddrDetected(e.nicID, targetAddr) + // + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } return } @@ -395,7 +410,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // address cache with the link address for the target of the message. if len(targetLinkAddr) != 0 { if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) return } @@ -423,7 +438,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme localAddr = "" } - r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -438,6 +453,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme Data: pkt.Data, }) packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) @@ -477,7 +493,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme stack := r.Stack() // Is the networking stack operating as a router? - if !stack.Forwarding() { + if !stack.Forwarding(ProtocolNumber) { // ... No, silently drop the packet. received.RouterOnlyPacketsDroppedByHost.Increment() return @@ -566,9 +582,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } - // Tell the NIC to handle the RA. - stack := r.Stack() - stack.HandleNDPRA(e.nicID, routerAddr, ra) + e.mu.Lock() + e.mu.ndp.handleRA(routerAddr, ra) + e.mu.Unlock() case header.ICMPv6RedirectMsg: // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating @@ -637,6 +653,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, }) icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr.SetType(header.ICMPv6NeighborSolicit) copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr @@ -665,3 +682,159 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return tcpip.LinkAddress([]byte(nil)), false } + +// ======= ICMP Error packet generation ========= + +// icmpReason is a marker interface for IPv6 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonParameterProblem is an error during processing of extension headers +// or the fixed header defined in RFC 4443 section 3.4. +type icmpReasonParameterProblem struct { + code header.ICMPv6Code + + // respondToMulticast indicates that we are sending a packet that falls under + // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondToMulticast bool + + // pointer is defined in the RFC 4443 setion 3.4 which reads: + // + // Pointer Identifies the octet offset within the invoking packet + // where the error was detected. + // + // The pointer will point beyond the end of the ICMPv6 + // packet if the field in error is beyond what can fit + // in the maximum size of an ICMPv6 error message. + pointer uint32 +} + +func (*icmpReasonParameterProblem) isICMPReason() {} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv6 and sends it. +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + stats := r.Stats().ICMP + sent := stats.V6PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + // Only send ICMP error if the address is not a multicast v6 + // address and the source is not the unspecified address. + // + // There are exceptions to this rule. + // See: point e.3) RFC 4443 section-2.4 + // + // (e) An ICMPv6 error message MUST NOT be originated as a result of + // receiving the following: + // + // (e.1) An ICMPv6 error message. + // + // (e.2) An ICMPv6 redirect message [IPv6-DISC]. + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + // + var allowResponseToMulticast bool + if reason, ok := reason.(*icmpReasonParameterProblem); ok { + allowResponseToMulticast = reason.respondToMulticast + } + + if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + return nil + } + + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + + if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { + // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. + // Unfortunately at this time ICMP Packets do not have a transport + // header separated out. It is in the Data part so we need to + // separate it out now. We will just pretend it is a minimal length + // ICMP packet as we don't really care if any later bits of a + // larger ICMP packet are in the header view or in the Data view. + transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) + if !ok { + return nil + } + typ := header.ICMPv6(transport).Type() + if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { + return nil + } + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + available := int(mtu) - headerLen + if available < header.IPv6MinimumSize { + return nil + } + payloadLen := network.Size() + transport.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload.CapLength(payloadLen) + + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + + icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + icmpHdr.SetType(header.ICMPv6ParamProblem) + icmpHdr.SetCode(reason.code) + icmpHdr.SetTypeSpecific(reason.pointer) + counter = sent.ParamProblem + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + counter = sent.DstUnreachable + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) + err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) + if err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8112ed051..31370c1d4 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -75,7 +75,8 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { + return stack.TransportPacketHandled } type stubLinkAddressCache struct { @@ -102,6 +103,30 @@ func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Lin func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -120,8 +145,8 @@ func TestICMPCounts(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) { @@ -149,9 +174,13 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) @@ -258,8 +287,8 @@ func TestICMPCounts(t *testing.T) { func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) { @@ -287,9 +316,13 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) @@ -423,12 +456,12 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), } @@ -723,12 +756,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: test.useNeighborCache, }) if isRouter { // Enabling forwarding makes the stack act as a router. - s.SetForwarding(true) + s.SetForwarding(ProtocolNumber, true) } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -919,7 +952,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -1097,7 +1130,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -1203,7 +1236,10 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := NewProtocol() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index af3cd91c6..c8a3e0b34 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv6 contains the implementation of the ipv6 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv6.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv6 contains the implementation of the ipv6 network protocol. package ipv6 import ( "fmt" + "sort" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -44,14 +42,302 @@ const ( DefaultTTL = 64 ) +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) +var _ stack.NDPEndpoint = (*endpoint)(nil) +var _ NDPEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + + // enabled is set to 1 when the endpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + ndp ndpState + } +} + +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it is passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + +// InvalidateDefaultRouter implements stack.NDPEndpoint. +func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.invalidateDefaultRouter(rtr) +} + +// SetNDPConfigurations implements NDPEndpoint. +func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) { + c.validate() + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.configs = c +} + +// hasTentativeAddr returns true if addr is tentative on e. +func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { + e.mu.RLock() + addressEndpoint := e.getAddressRLocked(addr) + e.mu.RUnlock() + return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative +} + +// dupTentativeAddrDetected attempts to inform e that a tentative addr is a +// duplicate on a link. +// +// dupTentativeAddrDetected removes the tentative address if it exists. If the +// address was generated via SLAAC, an attempt is made to generate a new +// address. +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return tcpip.ErrBadAddress + } + + if addressEndpoint.GetKind() != stack.PermanentTentative { + return tcpip.ErrInvalidEndpointState + } + + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + return err + } + + prefix := addressEndpoint.AddressWithPrefix().Subnet() + + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil +} + +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.Enabled() { + return + } + + if forwarding { + // When transitioning into an IPv6 router, host-only state (NDP discovered + // routers, discovered on-link prefixes, and auto-generated addresses) is + // cleaned up/invalidated and NDP router solicitations are stopped. + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(true /* hostOnly */) + } else { + // When transitioning into an IPv6 host, NDP router solicitations are + // started. + e.mu.ndp.startSolicitingRouters() + } +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + // + // Also auto-generate an IPv6 link-local address based on the endpoint's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the endpoint + // was last enabled, other devices may have acquired the same addresses. + var err *tcpip.Error + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + addr := addressEndpoint.AddressWithPrefix().Address + if !header.IsV6UnicastAddress(addr) { + return true + } + + switch addressEndpoint.GetKind() { + case stack.Permanent: + addressEndpoint.SetKind(stack.PermanentTentative) + fallthrough + case stack.PermanentTentative: + err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint) + return err == nil + default: + return true + } + }) + if err != nil { + return err + } + + // Do not auto-generate an IPv6 link-local address for loopback devices. + if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) + } + + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyway. + // + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !e.protocol.Forwarding() { + e.mu.ndp.startSolicitingRouters() + } + + return nil +} + +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() +} + +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(false /* hostOnly */) + e.stopDADForPermanentAddressesLocked() + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) + } +} + +// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) stopDADForPermanentAddressesLocked() { + // Stop DAD for all the tentative unicast addresses. + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.GetKind() != stack.PermanentTentative { + return true + } + + addr := addressEndpoint.AddressWithPrefix().Address + if header.IsV6UnicastAddress(addr) { + e.mu.ndp.stopDuplicateAddressDetection(addr) + } + + return true + }) } // DefaultTTL is the default hop limit for this endpoint. @@ -65,16 +351,6 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID -} - -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { @@ -107,6 +383,32 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() + return nil + } + + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() @@ -121,8 +423,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) + return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -138,9 +444,54 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe e.addIPHeader(r, pb, params) } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } + return n, err + } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + + // Slow path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err + } + n++ + } + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet @@ -153,12 +504,24 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + // As per RFC 4291 section 2.7: + // Multicast addresses must not be used as source addresses in IPv6 + // packets or appear in any Routing header. + if header.IsV6MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -169,7 +532,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false - for firstHeader := true; ; firstHeader = false { + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() + return + } + + for { + // Keep track of the start of the previous header so we can report the + // special case of a Hop by Hop at a location other than at the start. + previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -183,11 +558,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 - // (unrecognized next header) error in response to an extension header's - // Next Header field with the Hop By Hop extension header identifier. - if !firstHeader { + if previousHeaderStart != 0 { + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + }, pkt) return } @@ -209,13 +584,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) @@ -226,16 +613,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.4, if a node encounters a routing header with // an unrecognized routing type value, with a non-zero Segments Left // value, the node must discard the packet and send an ICMP Parameter - // Problem, Code 0. If the Segments Left is 0, the node must ignore the - // Routing extension header and process the next header in the packet. + // Problem, Code 0 to the packet's Source Address, pointing to the + // unrecognized Routing Type. + // + // If the Segments Left is 0, the node must ignore the Routing extension + // header and process the next header in the packet. // // Note, the stack does not yet handle any type of routing extension // header, so we just make sure Segments Left is zero before processing // the next extension header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for - // unrecognized routing types with a non-zero Segments Left value. if extHdr.SegmentsLeft() != 0 { + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: it.ParseOffset(), + }, pkt) return } @@ -268,7 +659,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() return } if done { @@ -311,12 +702,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - last := start + uint16(fragmentPayloadLen) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < start { + // Drop the fragment if the size of the reassembled payload would exceed + // the maximum payload size. + if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return @@ -333,7 +722,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ID: extHdr.ID(), }, start, - last, + start+uint16(fragmentPayloadLen)-1, extHdr.More(), uint8(rawPayload.Identifier), rawPayload.Buf, @@ -372,13 +761,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) @@ -397,21 +798,55 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf + r.Stats().IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. - e.dispatcher.DeliverTransportPacket(r, p, pkt) + switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + case stack.TransportPacketHandled: + case stack.TransportPacketDestinationPortUnreachable: + // As per RFC 4443 section 3.1: + // A destination node SHOULD originate a Destination Unreachable + // message with Code 4 in response to a packet for which the + // transport protocol (e.g., UDP) has no listener, if that transport + // protocol has no alternative means to inform the sender. + _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC 8200 section 4. (page 7): + // Extension headers are numbered from IANA IP Protocol Numbers + // [IANA-PN], the same values used for IPv4 and IPv6. When + // processing a sequence of Next Header values in a packet, the + // first one that is not an extension header [IANA-EH] indicates + // that the next item in the packet is the corresponding upper-layer + // header. + // With more related information on page 8: + // If, as a result of processing a header, the destination node is + // required to proceed to the next header but the Next Header value + // in the current header is unrecognized by the node, it should + // discard the packet and send an ICMP Parameter Problem message to + // the source of the packet, with an ICMP Code value of 1 + // ("unrecognized Next Header type encountered") and the ICMP + // Pointer field containing the offset of the unrecognized value + // within the original packet. + // + // Which when taken together indicate that an unknown protocol should + // be treated as an unrecognized next header value. + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) + default: + panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) + } } default: - // If we receive a packet for an extension header we do not yet handle, - // drop the packet for now. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) r.Stats().UnknownProtocolRcvdPackets.Increment() return } @@ -419,19 +854,340 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } // Close cleans up resources associated with the endpoint. -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + e.disableLocked() + e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */) + e.stopDADForPermanentAddressesLocked() + e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() + + e.protocol.forgetEndpoint(e) +} // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + // TODO(b/169350103): add checks here after making sure we no longer receive + // an empty address. + e.mu.Lock() + defer e.mu.Unlock() + return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) +} + +// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but +// with locking requirements. +// +// addAndAcquirePermanentAddressLocked also joins the passed address's +// solicited-node multicast group and start duplicate address detection. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err != nil { + return nil, err + } + + if !header.IsV6UnicastAddress(addr.Address) { + return addressEndpoint, nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { + return nil, err + } + + addressEndpoint.SetKind(stack.PermanentTentative) + + if e.Enabled() { + if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil { + return nil, err + } + } + + return addressEndpoint, nil +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + return e.removePermanentEndpointLocked(addressEndpoint, true) +} + +// removePermanentEndpointLocked is like removePermanentAddressLocked except +// it works with a stack.AddressEndpoint. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { + addr := addressEndpoint.AddressWithPrefix() + unicast := header.IsV6UnicastAddress(addr.Address) + if unicast { + e.mu.ndp.stopDuplicateAddressDetection(addr.Address) + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + switch addressEndpoint.ConfigType() { + case stack.AddressConfigSlaac: + e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case stack.AddressConfigSlaacTemp: + e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + } + } + + if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil { + return err + } + + if !unicast { + return nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil +} + +// hasPermanentAddressLocked returns true if the endpoint has a permanent +// address equal to the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return false + } + return addressEndpoint.GetKind().IsPermanent() +} + +// getAddressRLocked returns the endpoint for the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { + return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) +} + +// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with +// locking requirements. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB) +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + // addrCandidate is a candidate for Source Address Selection, as per + // RFC 6724 section 5. + type addrCandidate struct { + addressEndpoint stack.AddressEndpoint + scope header.IPv6AddressScope + } + + if len(remoteAddr) == 0 { + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + var cs []addrCandidate + e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !addressEndpoint.IsAssigned(allowExpired) { + return + } + + addr := addressEndpoint.AddressWithPrefix().Address + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) + } + + cs = append(cs, addrCandidate{ + addressEndpoint: addressEndpoint, + scope: scope, + }) + }) + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return true + } + if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { + return saTemp + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if c.addressEndpoint.IncRef() { + return c.addressEndpoint + } + } + + return nil +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV6MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) + type protocol struct { + stack *stack.Stack + + mu struct { + sync.RWMutex + + eps map[*endpoint]struct{} + } + // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. - defaultTTL uint32 + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. + defaultTTL uint32 + + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + fragmentation *fragmentation.Fragmentation + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. + ndpConfigs NDPConfigurations + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + + // autoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -456,16 +1212,36 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ - nicID: nicID, - linkEP: linkEP, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, protocol: p, - stack: st, } + e.mu.addressableEndpointState.Init(e) + e.mu.ndp = ndpState{ + ep: e, + configs: p.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), + } + e.mu.ndp.initializeTempAddrState() + + p.mu.Lock() + defer p.mu.Unlock() + p.mu.eps[e] = struct{}{} + return e +} + +func (p *protocol) forgetEndpoint(e *endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, e) } // SetOption implements NetworkProtocol.SetOption. @@ -506,75 +1282,43 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return 0, false, false } - ipHdr := header.IPv6(hdr) - // dataClone consists of: - // - Any IPv6 header bytes after the first 40 (i.e. extensions). - // - The transport header, if present. - // - Any other payload data. - views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + return proto, !fragMore && fragOffset == 0, true +} - // Iterate over the IPv6 extensions to find their length. - // - // Parsing occurs again in HandlePacket because we don't track the - // extensions in PacketBuffer. Unfortunately, that means HandlePacket - // has to do the parsing work again. - var nextHdr tcpip.TransportProtocolNumber - foundNext := true - extensionsSize := 0 -traverseExtensions: - for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { - if err != nil { - break - } - // If we exhaust the extension list, the entire packet is the IPv6 header - // and (possibly) extensions. - if done { - extensionsSize = dataClone.Size() - foundNext = false - break - } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} - switch extHdr := extHdr.(type) { - case header.IPv6FragmentExtHdr: - // If this is an atomic fragment, we don't have to treat it specially. - if !extHdr.More() && extHdr.FragmentOffset() == 0 { - continue - } - // This is a non-atomic fragment and has to be re-assembled before we can - // examine the payload for a transport header. - foundNext = false +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.SwapUint32(&p.forwarding, 1) == 0 + } + return atomic.SwapUint32(&p.forwarding, 0) == 1 +} - case header.IPv6RawPayloadHeader: - // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() - nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) - break traverseExtensions +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + p.mu.Lock() + defer p.mu.Unlock() - default: - // Any other extension is a no-op, keep looping until we find the payload. - } + if !p.setForwarding(v) { + return } - // Put the IPv6 header with extensions in pkt.NetworkHeader(). - hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) - if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + for ep := range p.mu.eps { + ep.transitionForwarding(v) } - ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber - - return nextHdr, foundNext, true } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -587,10 +1331,69 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -// NewProtocol returns an IPv6 network protocol. -func NewProtocol() stack.NetworkProtocol { - return &protocol{ - defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), +// Options holds options to configure a new protocol. +type Options struct { + // NDPConfigs is the default NDP configurations used by interfaces. + NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // + // Note, setting this to true does not mean that a link-local address is + // assigned right away, or at all. If Duplicate Address Detection is enabled, + // an address is only assigned if it successfully resolves. If it fails, no + // further attempts are made to auto-generate an IPv6 link-local adddress. + // + // The generated link-local address follows RFC 4291 Appendix A guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte +} + +// NewProtocolWithOptions returns an IPv6 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { + opts.NDPConfigs.validate() + + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + + ndpDisp: opts.NDPDisp, + ndpConfigs: opts.NDPConfigs, + opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + } + p.mu.eps = make(map[*endpoint]struct{}) + p.SetDefaultTTL(DefaultTTL) + return p } } + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 54787198f..d7f82973b 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,13 +15,16 @@ package ipv6 import ( + "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -139,18 +142,18 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -172,11 +175,11 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) @@ -184,8 +187,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -271,7 +274,7 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -299,11 +302,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { name string extHdr func(nextHdr uint8) ([]byte, uint8) shouldAccept bool + // Should we expect an ICMP response and if so, with what contents? + expectICMP bool + ICMPType header.ICMPv6Type + ICMPCode header.ICMPv6Code + pointer uint32 + multicast bool }{ { name: "None", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, shouldAccept: true, + expectICMP: false, }, { name: "hopbyhop with unknown option skippable action", @@ -334,9 +344,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { - name: "hopbyhop with unknown option discard and send icmp action", + name: "hopbyhop with unknown option discard and send icmp action (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -346,12 +357,58 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + name: "hopbyhop with unknown option discard and send icmp action (multicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + multicast: true, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -362,39 +419,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: false, }, { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: true, }, { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 1, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6ErroneousHeader, + pointer: header.IPv6FixedHeaderSize + 2, }, { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 0, 0, 0, 0, + }, fragmentExtHdrID + }, shouldAccept: true, }, { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: true, + expectICMP: false, }, { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: false, + expectICMP: false, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{}, + noNextHdrID + }, shouldAccept: false, + expectICMP: false, }, { name: "destination with unknown option skippable action", @@ -410,6 +505,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: true, + expectICMP: false, }, { name: "destination with unknown option discard action", @@ -425,9 +521,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "destination with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. + }, destinationExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action", + name: "destination with unknown option discard and send icmp action (muilticast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -437,12 +554,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. }, destinationExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action unless multicast dest", + name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -453,22 +576,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "routing - atomic fragment", + name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + nextHdr, 1, - // Fragment extension header. - nextHdr, 0, 0, 0, 1, 2, 3, 4, - }, routingExtHdrID + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. + }, destinationExtHdrID }, - shouldAccept: true, + shouldAccept: false, + expectICMP: false, + multicast: true, }, { name: "atomic fragment - routing", @@ -502,12 +636,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { return []byte{ // Routing extension header. hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. // Hop By Hop extension header with skippable unknown option. nextHdr, 0, 62, 4, 1, 2, 3, 4, }, routingExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { + name: "routing - hop by hop (with send icmp unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. + + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, routingExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, }, { name: "No next header", @@ -551,6 +715,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", @@ -571,16 +736,17 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -588,6 +754,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + // Add a default route so that a return packet knows where to go. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -629,12 +803,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, HopLimit: 255, SrcAddr: addr1, - DstAddr: addr2, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -648,6 +826,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 0", got) } + if !test.expectICMP { + if p, ok := e.Read(); ok { + t.Fatalf("unexpected packet received: %#v", p) + } + return + } + + // ICMP required. + p, ok := e.Read() + if !ok { + t.Fatalf("expected packet wasn't written out") + } + + // Pack the output packet into a single buffer.View as the checkers + // assume that. + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { + t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) + } + + ipHdr := header.IPv6(pkt) + checker.IPv6(t, ipHdr, checker.ICMPv6( + checker.ICMPv6Type(test.ICMPType), + checker.ICMPv6Code(test.ICMPCode))) + + // We know we are looking at no extension headers in the error ICMP + // packets. + icmpPkt := header.ICMPv6(ipHdr.Payload()) + // We know we sent small packets that won't be truncated when reflected + // back to us. + originalPacket := icmpPkt.Payload() + if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { + t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) + } + if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { + t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) + } return } @@ -687,6 +903,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 8200 section 4.5). udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize fragmentExtHdrLen = 8 // Note, not all routing extension headers will be 8 bytes but this test // uses 8 byte routing extension headers for most sub tests. @@ -731,6 +948,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) + var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte + udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] + ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) + tests := []struct { name string expectedPayload []byte @@ -1020,6 +1241,44 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: nil, }, { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:65520], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[65520:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { @@ -1504,8 +1763,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -1572,3 +1831,308 @@ func TestReceiveIPv6Fragments(t *testing.T) { }) } } + +func TestInvalidIPv6Fragments(t *testing.T) { + const ( + nicID = 1 + fragmentExtHdrLen = 8 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + tests := []struct { + name string + fragments []fragmentData + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + name: "fragments reassembled into a payload exceeding the max IPv6 payload size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, + []buffer.View{ + // Fragment extension header. + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, + ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, + 0, 0, 0, 1}), + // Payload length = 16 + payloadGen(16), + }, + ), + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + }) + e := channel.New(0, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + + // Serialize IPv6 fixed header. + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(f.data.Size()), + NextHeader: f.nextHdr, + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, + }) + + vv := hdr.View().ToVectorisedView() + vv.Append(f.data) + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { + t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { + t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) + } + }) + } +} + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } + const ( + src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ) + if err := s.AddAddress(1, ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + } + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + if err != nil { + t.Fatalf("NewSubnet(_, _) failed: %v", err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) + } + return rt +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if !hasEP { + t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) + } + } + + ep.Close() + + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) + } + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index b0873d1af..48a4c65e3 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +package ipv6 import ( "fmt" @@ -23,9 +23,27 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor // Solicitation messages to send when doing Duplicate Address Detection // for a tentative address. @@ -34,7 +52,7 @@ const ( defaultDupAddrDetectTransmits = 1 // defaultMaxRtrSolicitations is the default number of Router - // Solicitation messages to send when a NIC becomes enabled. + // Solicitation messages to send when an IPv6 endpoint becomes enabled. // // Default = 3 (from RFC 4861 section 10). defaultMaxRtrSolicitations = 3 @@ -131,7 +149,7 @@ const ( minRegenAdvanceDuration = time.Duration(0) // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt - // SLAAC address regenerations in response to a NIC-local conflict. + // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. maxSLAACAddrLocalRegenAttempts = 10 ) @@ -163,7 +181,7 @@ var ( // This is exported as a variable (instead of a constant) so tests // can update it to a smaller value. // - // This value guarantees that a temporary address will be preferred for at + // This value guarantees that a temporary address is preferred for at // least 1hr if the SLAAC prefix is valid for at least that time. MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour @@ -173,11 +191,17 @@ var ( // This is exported as a variable (instead of a constant) so tests // can update it to a smaller value. // - // This value guarantees that a temporary address will be valid for at least + // This value guarantees that a temporary address is valid for at least // 2hrs if the SLAAC prefix is valid for at least that time. MinMaxTempAddrValidLifetime = 2 * time.Hour ) +// NDPEndpoint is an endpoint that supports NDP. +type NDPEndpoint interface { + // SetNDPConfigurations sets the NDP configurations. + SetNDPConfigurations(NDPConfigurations) +} + // DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an // NDP Router Advertisement informed the Stack about. type DHCPv6ConfigurationFromNDPRA int @@ -192,7 +216,7 @@ const ( // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. // // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 - // will return all available configuration information. + // returns all available configuration information when serving addresses. DHCPv6ManagedAddress // DHCPv6OtherConfigurations indicates that other configuration information is @@ -207,19 +231,18 @@ const ( // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { - // OnDuplicateAddressDetectionStatus will be called when the DAD process - // for an address (addr) on a NIC (with ID nicID) completes. resolved - // will be set to true if DAD completed successfully (no duplicate addr - // detected); false otherwise (addr was detected to be a duplicate on - // the link the NIC is a part of, or it was stopped for some other - // reason, such as the address being removed). If an error occured - // during DAD, err will be set and resolved must be ignored. + // OnDuplicateAddressDetectionStatus is called when the DAD process for an + // address (addr) on a NIC (with ID nicID) completes. resolved is set to true + // if DAD completed successfully (no duplicate addr detected); false otherwise + // (addr was detected to be a duplicate on the link the NIC is a part of, or + // it was stopped for some other reason, such as the address being removed). + // If an error occured during DAD, err is set and resolved must be ignored. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) - // OnDefaultRouterDiscovered will be called when a new default router is + // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered // router should be remembered. // @@ -227,56 +250,55 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool - // OnDefaultRouterInvalidated will be called when a discovered default - // router that was remembered is invalidated. + // OnDefaultRouterInvalidated is called when a discovered default router that + // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) - // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true if the newly discovered - // on-link prefix should be remembered. + // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. + // Implementations must return true if the newly discovered on-link prefix + // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool - // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix that was remembered is invalidated. + // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that + // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) - // OnAutoGenAddress will be called when a new prefix with its - // autonomous address-configuration flag set has been received and SLAAC - // has been performed. Implementations may prevent the stack from - // assigning the address to the NIC by returning false. + // OnAutoGenAddress is called when a new prefix with its autonomous address- + // configuration flag set is received and SLAAC was performed. Implementations + // may prevent the stack from assigning the address to the NIC by returning + // false. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool - // OnAutoGenAddressDeprecated will be called when an auto-generated - // address (as part of SLAAC) has been deprecated, but is still - // considered valid. Note, if an address is invalidated at the same - // time it is deprecated, the deprecation event MAY be omitted. + // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) + // is deprecated, but is still considered valid. Note, if an address is + // invalidated at the same ime it is deprecated, the deprecation event may not + // be received. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) - // OnAutoGenAddressInvalidated will be called when an auto-generated - // address (as part of SLAAC) has been invalidated. + // OnAutoGenAddressInvalidated is called when an auto-generated address + // (SLAAC) is invalidated. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) - // OnRecursiveDNSServerOption will be called when an NDP option with - // recursive DNS servers has been received. Note, addrs may contain - // link-local addresses. + // OnRecursiveDNSServerOption is called when the stack learns of DNS servers + // through NDP. Note, the addresses may contain link-local addresses. // // It is up to the caller to use the DNS Servers only for their valid // lifetime. OnRecursiveDNSServerOption may be called for new or @@ -288,8 +310,8 @@ type NDPDispatcher interface { // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) - // OnDNSSearchListOption will be called when an NDP option with a DNS - // search list has been received. + // OnDNSSearchListOption is called when the stack learns of DNS search lists + // through NDP. // // It is up to the caller to use the domain names in the search list // for only their valid lifetime. OnDNSSearchListOption may be called @@ -298,8 +320,8 @@ type NDPDispatcher interface { // be increased, decreased or completely invalidated when lifetime = 0. OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) - // OnDHCPv6Configuration will be called with an updated configuration that is - // available via DHCPv6 for a specified NIC. + // OnDHCPv6Configuration is called with an updated configuration that is + // available via DHCPv6 for the passed NIC. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. @@ -320,7 +342,7 @@ type NDPConfigurations struct { // Must be greater than or equal to 1ms. RetransmitTimer time.Duration - // The number of Router Solicitation messages to send when the NIC + // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. MaxRtrSolicitations uint8 @@ -335,24 +357,22 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements will be - // processed. + // HandleRAs determines whether or not Router Advertisements are processed. HandleRAs bool - // DiscoverDefaultRouters determines whether or not default routers will - // be discovered from Router Advertisements. This configuration is - // ignored if HandleRAs is false. + // DiscoverDefaultRouters determines whether or not default routers are + // discovered from Router Advertisements, as per RFC 4861 section 6. This + // configuration is ignored if HandleRAs is false. DiscoverDefaultRouters bool - // DiscoverOnLinkPrefixes determines whether or not on-link prefixes - // will be discovered from Router Advertisements' Prefix Information - // option. This configuration is ignored if HandleRAs is false. + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are + // discovered from Router Advertisements' Prefix Information option, as per + // RFC 4861 section 6. This configuration is ignored if HandleRAs is false. DiscoverOnLinkPrefixes bool - // AutoGenGlobalAddresses determines whether or not global IPv6 - // addresses will be generated for a NIC in response to receiving a new - // Prefix Information option with its Autonomous Address - // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs + // SLAAC to auto-generate global SLAAC addresses in response to Prefix + // Information options, as per RFC 4862. // // Note, if an address was already generated for some unique prefix, as // part of SLAAC, this option does not affect whether or not the @@ -366,12 +386,12 @@ type NDPConfigurations struct { // // If the method used to generate the address does not support creating // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's - // MAC address), then no attempt will be made to resolve the conflict. + // MAC address), then no attempt is made to resolve the conflict. AutoGenAddressConflictRetries uint8 // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC - // addresses will be generated for a NIC as part of SLAAC privacy extensions, - // RFC 4941. + // addresses are generated for an IPv6 endpoint as part of SLAAC privacy + // extensions, as per RFC 4941. // // Ignored if AutoGenGlobalAddresses is false. AutoGenTempGlobalAddresses bool @@ -410,7 +430,7 @@ func DefaultNDPConfigurations() NDPConfigurations { } // validate modifies an NDPConfigurations with valid values. If invalid values -// are present in c, the corresponding default values will be used instead. +// are present in c, the corresponding default values are used instead. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer @@ -439,8 +459,8 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { - // The NIC this ndpState is for. - nic *NIC + // The IPv6 endpoint this ndpState is for. + ep *endpoint // configs is the per-interface NDP configurations. configs NDPConfigurations @@ -458,8 +478,8 @@ type ndpState struct { // Used to let the Router Solicitation timer know that it has been stopped. // // Must only be read from or written to while protected by the lock of - // the NIC this ndpState is associated with. MUST be set when the timer is - // set. + // the IPv6 endpoint this ndpState is associated with. MUST be set when the + // timer is set. done *bool } @@ -492,7 +512,7 @@ type dadState struct { // Used to let the DAD timer know that it has been stopped. // // Must only be read from or written to while protected by the lock of - // the NIC this dadState is associated with. + // the IPv6 endpoint this dadState is associated with. done *bool } @@ -537,7 +557,7 @@ type tempSLAACAddrState struct { // The address's endpoint. // // Must not be nil. - ref *referencedNetworkEndpoint + addressEndpoint stack.AddressEndpoint // Has a new temporary SLAAC address already been regenerated? regenerated bool @@ -567,10 +587,10 @@ type slaacPrefixState struct { // // May only be nil when the address is being (re-)generated. Otherwise, // must not be nil as all SLAAC prefixes must have a stable address. - ref *referencedNetworkEndpoint + addressEndpoint stack.AddressEndpoint - // The number of times an address has been generated locally where the NIC - // already had the generated address. + // The number of times an address has been generated locally where the IPv6 + // endpoint already had the generated address. localGenerationFailures uint8 } @@ -578,11 +598,12 @@ type slaacPrefixState struct { tempAddrs map[tcpip.Address]tempSLAACAddrState // The next two fields are used by both stable and temporary addresses - // generated for a SLAAC prefix. This is safe as only 1 address will be - // in the generation and DAD process at any time. That is, no two addresses - // will be generated at the same time for a given SLAAC prefix. + // generated for a SLAAC prefix. This is safe as only 1 address is in the + // generation and DAD process at any time. That is, no two addresses are + // generated at the same time for a given SLAAC prefix. - // The number of times an address has been generated and added to the NIC. + // The number of times an address has been generated and added to the IPv6 + // endpoint. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 @@ -597,16 +618,16 @@ type slaacPrefixState struct { // This function must only be called by IPv6 addresses that are currently // tentative. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { return tcpip.ErrAddressFamilyNotSupported } - if ref.getKind() != permanentTentative { + if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should be marked as tentative since we are starting DAD. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } // Should not attempt to perform DAD on an address that is currently in the @@ -617,18 +638,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // existed, we would get an error since we attempted to add a duplicate // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. - // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + // See endpoint.addAddressLocked. + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits if remaining == 0 { - ref.setKind(permanent) + addressEndpoint.SetKind(stack.Permanent) // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } return nil @@ -637,25 +658,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref var done bool var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the NIC's lock. This is effectively the same - // as starting a goroutine but we use a timer that fires immediately so we can - // reset it for the next DAD iteration. - timer = ndp.nic.stack.Clock().AfterFunc(0, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { + ndp.ep.mu.Lock() + defer ndp.ep.mu.Unlock() if done { // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the NIC lock and stopped DAD - // before this function obtained the NIC lock. Simply return here and do - // nothing further. + // another goroutine already obtained the IPv6 endpoint lock and stopped + // DAD before this function obtained the NIC lock. Simply return here and + // do nothing further. return } - if ref.getKind() != permanentTentative { + if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } dadDone := remaining == 0 @@ -663,33 +684,34 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref var err *tcpip.Error if !dadDone { // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) // Do not hold the lock when sending packets which may be a long running // task or may block link address resolution. We know this is safe // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the NIC. Note, DAD would be - // stopped if the NIC was disabled or removed, or if the address was - // removed. - ndp.nic.mu.Unlock() - err = ndp.sendDADPacket(addr, ref) - ndp.nic.mu.Lock() + // has been stopped before doing any work with the IPv6 endpoint. Note, + // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if + // the address was removed. + ndp.ep.mu.Unlock() + err = ndp.sendDADPacket(addr, addressEndpoint) + ndp.ep.mu.Lock() + addressEndpoint.DecRef() } if done { // If we reach this point, it means that DAD was stopped after we released - // the NIC's read lock and before we obtained the write lock. + // the IPv6 endpoint's read lock and before we obtained the write lock. return } if dadDone { // DAD has resolved. - ref.setKind(permanent) + addressEndpoint.SetKind(stack.Permanent) } else if err == nil { // DAD is not done and we had no errors when sending the last NDP NS, // schedule the next DAD timer. remaining-- - timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + timer.Reset(ndp.configs.RetransmitTimer) return } @@ -698,16 +720,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // integrator know DAD has completed. delete(ndp.dad, addr) - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } // If DAD resolved for a stable SLAAC address, attempt generation of a // temporary SLAAC address. - if dadDone && ref.configType == slaac { + if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) } }) @@ -722,28 +744,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns // addr. // -// addr must be a tentative IPv6 address on ndp's NIC. +// addr must be a tentative IPv6 address on ndp's IPv6 endpoint. // -// The NIC ndp belongs to MUST NOT be locked. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The IPv6 endpoint that ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. + // Since this method is required to be called when the IPv6 endpoint is not + // locked, the NIC could have been disabled or removed by another goroutine. if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { return err } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) } icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) @@ -752,17 +777,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd ns.SetTargetAddress(addr) icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - pkt := NewPacketBuffer(PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, - NetworkHeaderParams{ + stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - TOS: DefaultTOS, }, pkt, ); err != nil { sent.Dropped.Increment() @@ -778,11 +802,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd // such a state forever, unless some other external event resolves the DAD // process (receiving an NA from the true owner of addr, or an NS for addr // (implying another node is attempting to use addr)). It is up to the caller -// of this function to handle such a scenario. Normally, addr will be removed -// from n right after this function returns or the address successfully -// resolved. +// of this function to handle such a scenario. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { dad, ok := ndp.dad[addr] if !ok { @@ -801,30 +823,30 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } // handleRA handles a Router Advertisement message that arrived on the NIC // this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - // Is the NIC configured to handle RAs at all? + // Is the IPv6 endpoint configured to handle RAs at all? // // Currently, the stack does not determine router interface status on a - // per-interface basis; it is a stack-wide configuration, so we check - // stack's forwarding flag to determine if the NIC is a routing - // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + // per-interface basis; it is a protocol-wide configuration, so we check the + // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding + // packets. + if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { return } // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -839,11 +861,11 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { if ndp.dhcpv6Configuration != configuration { ndp.dhcpv6Configuration = configuration - ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration) } } - // Is the NIC configured to discover default routers? + // Is the IPv6 endpoint configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] rl := ra.RouterLifetime() @@ -881,20 +903,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.nic.stack.ndpDisp == nil { + if ndp.ep.protocol.ndpDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.nic.stack.ndpDisp == nil { + if ndp.ep.protocol.ndpDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -928,7 +950,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // invalidateDefaultRouter invalidates a discovered default router. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr, ok := ndp.defaultRouters[ip] @@ -942,32 +964,32 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } // rememberDefaultRouter remembers a newly discovered default router with IPv6 // link-local address ip with lifetime rl. // -// The router identified by ip MUST NOT already be known by the NIC. +// The router identified by ip MUST NOT already be known by the IPv6 endpoint. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return } state := defaultRouterState{ - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateDefaultRouter(ip) }), } @@ -982,22 +1004,22 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The prefix identified by prefix MUST NOT already be known. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return } state := onLinkPrefixState{ - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } @@ -1011,7 +1033,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) // invalidateOnLinkPrefix invalidates a discovered on-link prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { s, ok := ndp.onLinkPrefixes[prefix] @@ -1025,8 +1047,8 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1036,7 +1058,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { // handleOnLinkPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the on-link flag is set. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { prefix := pi.Subnet() prefixState, ok := ndp.onLinkPrefixes[prefix] @@ -1089,7 +1111,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // handleAutonomousPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the autonomous flag is set. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { vl := pi.ValidLifetime() pl := pi.PreferredLifetime() @@ -1125,7 +1147,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform // // pl is the new preferred lifetime. vl is the new valid lifetime. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 @@ -1142,15 +1164,15 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.stableAddr.ref) + ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint) }), - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1189,7 +1211,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } // If the address is assigned (DAD resolved), generate a temporary address. - if state.stableAddr.ref.getKind() == permanent { + if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) @@ -1198,32 +1220,27 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.slaacPrefixes[prefix] = state } -// addSLAACAddr adds a SLAAC address to the NIC. +// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return nil } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { // Informed by the integrator not to add the address. return nil } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr, - } - - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { - panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } - return ref + return addressEndpoint } // generateSLAACAddr generates a SLAAC address for prefix. @@ -1232,10 +1249,10 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo // // Panics if the prefix is not a SLAAC prefix or it already has an address. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.stableAddr.ref; r != nil { - panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix())) } // If we have already reached the maximum address generation attempts for the @@ -1255,11 +1272,11 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, - oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()), dadCounter, oIID.SecretKey, ) @@ -1272,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.nic.linkEP.LinkAddress() + linkAddr := ndp.ep.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } @@ -1291,15 +1308,15 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt PrefixLen: validPrefixLenForAutoGen, } - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { break } state.stableAddr.localGenerationFailures++ } - if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { - state.stableAddr.ref = ref + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true } @@ -1309,10 +1326,9 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // regenerateSLAACAddr regenerates an address for a SLAAC prefix. // -// If generating a new address for the prefix fails, the prefix will be -// invalidated. +// If generating a new address for the prefix fails, the prefix is invalidated. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { state, ok := ndp.slaacPrefixes[prefix] if !ok { @@ -1332,7 +1348,7 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { // generateTempSLAACAddr generates a new temporary SLAAC address. // -// If resetGenAttempts is true, the prefix's generation counter will be reset. +// If resetGenAttempts is true, the prefix's generation counter is reset. // // Returns true if a new address was generated. func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { @@ -1353,7 +1369,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - stableAddr := prefixState.stableAddr.ref.address() + stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address now := time.Now() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary @@ -1392,7 +1408,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - // Attempt to generate a new address that is not already assigned to the NIC. + // Attempt to generate a new address that is not already assigned to the IPv6 + // endpoint. var generatedAddr tcpip.AddressWithPrefix for i := 0; ; i++ { // If we were unable to generate an address after the maximum SLAAC address @@ -1402,7 +1419,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { break } } @@ -1410,13 +1427,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary // address with a zero preferred lifetime. The checks above ensure this // so we know the address is not deprecated. - ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) - if ref == nil { + addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */) + if addressEndpoint == nil { return false } state := tempSLAACAddrState{ - deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1427,9 +1444,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) } - ndp.deprecateSLAACAddress(tempAddrState.ref) + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) }), - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1442,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1465,8 +1482,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla prefixState.tempAddrs[generatedAddr.Address] = tempAddrState ndp.slaacPrefixes[prefix] = prefixState }), - createdAt: now, - ref: ref, + createdAt: now, + addressEndpoint: addressEndpoint, } state.deprecationJob.Schedule(pl) @@ -1481,7 +1498,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { state, ok := ndp.slaacPrefixes[prefix] if !ok { @@ -1496,14 +1513,14 @@ func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttemp // // pl is the new preferred lifetime. vl is the new valid lifetime. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) + ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint) } else { - prefixState.stableAddr.ref.deprecated = false + prefixState.stableAddr.addressEndpoint.SetDeprecated(false) } // If prefix was preferred for some finite lifetime before, cancel the @@ -1565,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // If DAD is not yet complete on the stable address, there is no need to do // work with temporary addresses. - if prefixState.stableAddr.ref.getKind() != permanent { + if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent { return } @@ -1608,9 +1625,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat newPreferredLifetime := preferredUntil.Sub(now) tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { - ndp.deprecateSLAACAddress(tempAddrState.ref) + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) } else { - tempAddrState.ref.deprecated = false + tempAddrState.addressEndpoint.SetDeprecated(false) tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } @@ -1635,8 +1652,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // due to an update in preferred lifetime. // // If each temporay address has already been regenerated, no new temporary - // address will be generated. To ensure continuation of temporary SLAAC - // addresses, we manually try to regenerate an address here. + // address is generated. To ensure continuation of temporary SLAAC addresses, + // we manually try to regenerate an address here. if len(regenForAddr) != 0 || allAddressesRegenerated { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. @@ -1647,57 +1664,58 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } } -// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP -// dispatcher that ref has been deprecated. +// deprecateSLAACAddress marks the address as deprecated and notifies the NDP +// dispatcher that address has been deprecated. // -// deprecateSLAACAddress does nothing if ref is already deprecated. +// deprecateSLAACAddress does nothing if the address is already deprecated. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { - if ref.deprecated { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) { + if addressEndpoint.Deprecated() { return } - ref.deprecated = true - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) + addressEndpoint.SetDeprecated(true) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } // invalidateSLAACPrefix invalidates a SLAAC prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.stableAddr.ref; r != nil { + ndp.cleanupSLAACPrefixResources(prefix, state) + + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) + if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) } } - - ndp.cleanupSLAACPrefixResources(prefix, state) } // cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's // resources. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() { + if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.stableAddr.ref = nil + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil ndp.slaacPrefixes[prefix] = state return } @@ -1709,14 +1727,17 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr // // Panics if the SLAAC prefix is not known. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { // Invalidate all temporary addresses. for tempAddr, tempAddrState := range state.tempAddrs { ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) } - state.stableAddr.ref = nil + if state.stableAddr.addressEndpoint != nil { + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil + } state.deprecationJob.Cancel() state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) @@ -1724,12 +1745,12 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa // invalidateTempSLAACAddr invalidates a temporary SLAAC address. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { // Since we are already invalidating the address, do not invalidate the // address when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) } ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) @@ -1738,10 +1759,10 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary // SLAAC address's resources from ndp. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } if !invalidateAddr { @@ -1765,35 +1786,29 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's // jobs and entry. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.addressEndpoint.DecRef() + tempAddrState.addressEndpoint = nil tempAddrState.deprecationJob.Cancel() tempAddrState.invalidationJob.Cancel() tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } -// cleanupState cleans up ndp's state. -// -// If hostOnly is true, then only host-specific state will be cleaned up. +// removeSLAACAddresses removes all SLAAC addresses. // -// cleanupState MUST be called with hostOnly set to true when ndp's NIC is -// transitioning from a host to a router. This function will invalidate all -// discovered on-link prefixes, discovered routers, and auto-generated -// addresses. -// -// If hostOnly is true, then the link-local auto-generated address will not be -// invalidated as routers are also expected to generate a link-local address. +// If keepLinkLocal is false, the SLAAC generated link-local address is removed. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalPrefixes := 0 + var linkLocalPrefixes int for prefix, state := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. - if hostOnly && prefix == linkLocalSubnet { + if keepLinkLocal && prefix == linkLocalSubnet { linkLocalPrefixes++ continue } @@ -1804,6 +1819,21 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) } +} + +// cleanupState cleans up ndp's state. +// +// If hostOnly is true, then only host-specific state is cleaned up. +// +// This function invalidates all discovered on-link prefixes, discovered +// routers, and auto-generated addresses. +// +// If hostOnly is true, then the link-local auto-generated address aren't +// invalidated as routers are also expected to generate a link-local address. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupState(hostOnly bool) { + ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1827,7 +1857,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // -// The NIC ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. @@ -1848,27 +1878,37 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { - ndp.nic.mu.Lock() + ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { + ndp.ep.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the NIC lock and stopped solicitations. - // Simply return here and do nothing further. - ndp.nic.mu.Unlock() + // goroutine already obtained the IPv6 endpoint lock and stopped + // solicitations. Simply return here and do nothing further. + ndp.ep.mu.Unlock() return } // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) - if ref == nil { - ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) + if addressEndpoint == nil { + // Incase this ends up creating a new temporary address, we need to hold + // onto the endpoint until a route is obtained. If we decrement the + // reference count before obtaing a route, the address's resources would + // be released and attempting to obtain a route after would fail. Once a + // route is obtainted, it is safe to decrement the reference count since + // obtaining a route increments the address's reference count. + addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) } - ndp.nic.mu.Unlock() + ndp.ep.mu.Unlock() - localAddr := ref.address() - r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + localAddr := addressEndpoint.AddressWithPrefix().Address + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) + addressEndpoint.DecRef() + if err != nil { + return + } defer r.Release() // Route should resolve immediately since @@ -1876,15 +1916,16 @@ func (ndp *ndpState) startSolicitingRouters() { // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. + // Since this method is required to be called when the IPv6 endpoint is + // not locked, the IPv6 endpoint could have been disabled or removed by + // another goroutine. if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { return } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1907,21 +1948,20 @@ func (ndp *ndpState) startSolicitingRouters() { rs.Options().Serialize(optsSerializer) icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - pkt := NewPacketBuffer(PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, - NetworkHeaderParams{ + stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - TOS: DefaultTOS, }, pkt, ); err != nil { sent.Dropped.Increment() - log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { @@ -1929,19 +1969,19 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.nic.mu.Lock() + ndp.ep.mu.Lock() if done || remaining == 0 { ndp.rtrSolicit.timer = nil ndp.rtrSolicit.done = nil } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but - // we still reached this point, then we know the NIC + // we still reached this point, then we know the IPv6 endpoint // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } - ndp.nic.mu.Unlock() + ndp.ep.mu.Unlock() }) } @@ -1949,7 +1989,7 @@ func (ndp *ndpState) startSolicitingRouters() { // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // -// The NIC ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { if ndp.rtrSolicit.timer == nil { // Nothing to do. @@ -1965,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() { // initializeTempAddrState initializes state related to temporary SLAAC // addresses. func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 480c495fa..25464a03a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -17,6 +17,7 @@ package ipv6 import ( "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -35,8 +36,8 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: useNeighborCache, }) @@ -65,10 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + t.Cleanup(ep.Close) + return s, ep } +var _ NDPDispatcher = (*testNDPDispatcher)(nil) + +// testNDPDispatcher is an NDPDispatcher only allows default router discovery. +type testNDPDispatcher struct { + addr tcpip.Address +} + +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { + t.addr = addr + return true +} + +func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { + t.addr = addr +} + +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { +} + +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return false +} + +func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { +} + +func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { +} + +func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { +} + +func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { + var ndpDisp testNDPDispatcher + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ + NDPDisp: &ndpDisp, + })}, + }) + + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + + ipv6EP := ep.(*endpoint) + ipv6EP.mu.Lock() + ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.Unlock() + + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } + + ndpDisp.addr = "" + ndpEP := ep.(stack.NDPEndpoint) + ndpEP.InvalidateDefaultRouter(lladdr1) + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } +} + // TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -98,7 +183,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -202,7 +287,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) @@ -475,7 +560,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: stackTyp.useNeighborCache, }) e := channel.New(1, 1280, nicLinkAddr) @@ -596,7 +681,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -707,7 +792,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) @@ -958,7 +1043,7 @@ func TestNDPValidation(t *testing.T) { if isRouter { // Enabling forwarding makes the stack act as a router. - s.SetForwarding(true) + s.SetForwarding(ProtocolNumber, true) } stats := s.Stats().ICMP.V6PacketsReceived @@ -1172,7 +1257,7 @@ func TestRouterAdvertValidation(t *testing.T) { e := channel.New(10, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: stackTyp.useNeighborCache, }) diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD new file mode 100644 index 000000000..c9e57dc0d --- /dev/null +++ b/pkg/tcpip/network/testutil/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + srcs = [ + "testutil.go", + ], + visibility = [ + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go new file mode 100644 index 000000000..7cc52985e --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil.go @@ -0,0 +1,144 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil defines types and functions used to test Network Layer +// functionality such as IP fragmentation. +package testutil + +import ( + "fmt" + "math/rand" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// MockLinkEndpoint is an endpoint used for testing, it stores packets written +// to it and can mock errors. +type MockLinkEndpoint struct { + // WrittenPackets is where packets written to the endpoint are stored. + WrittenPackets []*stack.PacketBuffer + + mtu uint32 + err *tcpip.Error + allowPackets int +} + +// NewMockLinkEndpoint creates a new MockLinkEndpoint. +// +// err is the error that will be returned once allowPackets packets are written +// to the endpoint. +func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { + return &MockLinkEndpoint{ + mtu: mtu, + err: err, + allowPackets: allowPackets, + } +} + +// MTU implements LinkEndpoint.MTU. +func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var n int + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + return n, err + } + n++ + } + + return n, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*MockLinkEndpoint) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*MockLinkEndpoint) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// MakeRandPkt generates a randomized packet. transportHeaderLength indicates +// how many random bytes will be copied in the Transport Header. +// extraHeaderReserveLength indicates how much extra space will be reserved for +// the other headers. The payload is made from Views of the sizes listed in +// viewSizes. +func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { + var views buffer.VectorisedView + + for _, s := range viewSizes { + newView := buffer.NewView(s) + if _, err := rand.Read(newView); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + views.AppendView(newView) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, + Data: views, + }) + pkt.NetworkProtocolNumber = proto + if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt +} diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 91fc26722..51d428049 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -127,8 +127,8 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 9e37cab18..8e0ee1cd7 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -112,8 +112,8 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) @@ -188,7 +188,7 @@ func main() { defer wq.EventUnregister(&waitEntry) for { - n, wq, err := ep.Accept() + n, wq, err := ep.Accept(nil) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 900938dd1..2eaeab779 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,8 +54,8 @@ go_template_instance( go_library( name = "stack", srcs = [ + "addressable_endpoint_state.go", "conntrack.go", - "dhcpv6configurationfromndpra_string.go", "forwarder.go", "headertype_string.go", "icmp_rate_limit.go", @@ -65,7 +65,6 @@ go_library( "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", - "ndp.go", "neighbor_cache.go", "neighbor_entry.go", "neighbor_entry_list.go", @@ -106,6 +105,7 @@ go_test( name = "stack_x_test", size = "medium", srcs = [ + "addressable_endpoint_state_test.go", "ndp_test.go", "nud_test.go", "stack_test.go", @@ -116,6 +116,7 @@ go_test( deps = [ ":stack", "//pkg/rand", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", @@ -138,7 +139,6 @@ go_test( name = "stack_test", size = "small", srcs = [ - "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", @@ -152,8 +152,8 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", - "@com_github_dpjacques_clockwork//:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go new file mode 100644 index 000000000..db8ac1c2b --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -0,0 +1,758 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) +var _ AddressableEndpoint = (*AddressableEndpointState)(nil) + +// AddressableEndpointState is an implementation of an AddressableEndpoint. +type AddressableEndpointState struct { + networkEndpoint NetworkEndpoint + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + endpoints map[tcpip.Address]*addressState + primary []*addressState + + // groups holds the mapping between group addresses and the number of times + // they have been joined. + groups map[tcpip.Address]uint32 + } +} + +// Init initializes the AddressableEndpointState with networkEndpoint. +// +// Must be called before calling any other function on m. +func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { + a.networkEndpoint = networkEndpoint + + a.mu.Lock() + defer a.mu.Unlock() + a.mu.endpoints = make(map[tcpip.Address]*addressState) + a.mu.groups = make(map[tcpip.Address]uint32) +} + +// ReadOnlyAddressableEndpointState provides read-only access to an +// AddressableEndpointState. +type ReadOnlyAddressableEndpointState struct { + inner *AddressableEndpointState +} + +// AddrOrMatching returns an endpoint for the passed address that is consisdered +// bound to the wrapped AddressableEndpointState. +// +// If addr is an exact match with an existing address, that address is returned. +// Otherwise, f is called with each address and the address that f returns true +// for is returned. +// +// Returns nil of no address matches. +func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + if ep, ok := m.inner.mu.endpoints[addr]; ok { + if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { + return ep + } + } + + for _, ep := range m.inner.mu.endpoints { + if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { + return ep + } + } + + return nil +} + +// Lookup returns the AddressEndpoint for the passed address. +// +// Returns nil if the passed address is not associated with the +// AddressableEndpointState. +func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + ep, ok := m.inner.mu.endpoints[addr] + if !ok { + return nil + } + return ep +} + +// ForEach calls f for each address pair. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + for _, ep := range m.inner.mu.endpoints { + if !f(ep) { + return + } + } +} + +// ForEachPrimaryEndpoint calls f for each primary address. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + for _, ep := range m.inner.mu.primary { + f(ep) + } +} + +// ReadOnly returns a readonly reference to a. +func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { + return ReadOnlyAddressableEndpointState{inner: a} +} + +func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.releaseAddressStateLocked(addrState) +} + +// releaseAddressState removes addrState from s's address state (primary and endpoints list). +// +// Preconditions: a.mu must be write locked. +func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) { + oldPrimary := a.mu.primary + for i, s := range a.mu.primary { + if s == addrState { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + oldPrimary[len(oldPrimary)-1] = nil + break + } + } + delete(a.mu.endpoints, addrState.addr.Address) +} + +// AddAndAcquirePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// AddAndAcquireTemporaryAddress adds a temporary address. +// +// Returns tcpip.ErrDuplicateAddress if the address exists. +// +// The temporary address's endpoint is acquired and returned. +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// addAndAcquireAddressLocked adds, acquires and returns a permanent or +// temporary address. +// +// If the addressable endpoint already has the address in a non-permanent state, +// and addAndAcquireAddressLocked is adding a permanent address, that address is +// promoted in place and its properties set to the properties provided. If the +// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// returned, regardless the kind of address that is being added. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { + // attemptAddToPrimary is false when the address is already in the primary + // address list. + attemptAddToPrimary := true + addrState, ok := a.mu.endpoints[addr.Address] + if ok { + if !permanent { + // We are adding a non-permanent address but the address exists. No need + // to go any further since we can only promote existing temporary/expired + // addresses to permanent. + return nil, tcpip.ErrDuplicateAddress + } + + addrState.mu.Lock() + if addrState.mu.kind.IsPermanent() { + addrState.mu.Unlock() + // We are adding a permanent address but a permanent address already + // exists. + return nil, tcpip.ErrDuplicateAddress + } + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr)) + } + + // We now promote the address. + for i, s := range a.mu.primary { + if s == addrState { + switch peb { + case CanBePrimaryEndpoint: + // The address is already in the primary address list. + attemptAddToPrimary = false + case FirstPrimaryEndpoint: + if i == 0 { + // The address is already first in the primary address list. + attemptAddToPrimary = false + } else { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + } + case NeverPrimaryEndpoint: + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + break + } + } + } + + if addrState == nil { + addrState = &addressState{ + addressableEndpointState: a, + addr: addr, + } + a.mu.endpoints[addr.Address] = addrState + addrState.mu.Lock() + // We never promote an address to temporary - it can only be added as such. + // If we are actaully adding a permanent address, it is promoted below. + addrState.mu.kind = Temporary + } + + // At this point we have an address we are either promoting from an expired or + // temporary address to permanent, promoting an expired address to temporary, + // or we are adding a new temporary or permanent address. + // + // The address MUST be write locked at this point. + defer addrState.mu.Unlock() + + if permanent { + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr)) + } + + // Primary addresses are biased by 1. + addrState.mu.refs++ + addrState.mu.kind = Permanent + } + // Acquire the address before returning it. + addrState.mu.refs++ + addrState.mu.deprecated = deprecated + addrState.mu.configType = configType + + if attemptAddToPrimary { + switch peb { + case NeverPrimaryEndpoint: + case CanBePrimaryEndpoint: + a.mu.primary = append(a.mu.primary, addrState) + case FirstPrimaryEndpoint: + if cap(a.mu.primary) == len(a.mu.primary) { + a.mu.primary = append([]*addressState{addrState}, a.mu.primary...) + } else { + // Shift all the endpoints by 1 to make room for the new address at the + // front. We could have just created a new slice but this saves + // allocations when the slice has capacity for the new address. + primaryCount := len(a.mu.primary) + a.mu.primary = append(a.mu.primary, nil) + if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount { + panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount)) + } + a.mu.primary[0] = addrState + } + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + } + + return addrState, nil +} + +// RemovePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.mu.groups[addr]; ok { + panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) + } + + return a.removePermanentAddressLocked(addr) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + addrState, ok := a.mu.endpoints[addr] + if !ok { + return tcpip.ErrBadLocalAddress + } + + return a.removePermanentEndpointLocked(addrState) +} + +// RemovePermanentEndpoint removes the passed endpoint if it is associated with +// a and permanent. +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { + addrState, ok := ep.(*addressState) + if !ok || addrState.addressableEndpointState != a { + return tcpip.ErrInvalidEndpointState + } + + return a.removePermanentEndpointLocked(addrState) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { + if !addrState.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + addrState.SetKind(PermanentExpired) + a.decAddressRefLocked(addrState) + return nil +} + +// decAddressRef decrements the address's reference count and releases it once +// the reference count hits 0. +func (a *AddressableEndpointState) decAddressRef(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.decAddressRefLocked(addrState) +} + +// decAddressRefLocked is like decAddressRef but with locking requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) { + addrState.mu.Lock() + defer addrState.mu.Unlock() + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr)) + } + + addrState.mu.refs-- + + if addrState.mu.refs != 0 { + return + } + + // A non-expired permanent address must not have its reference count dropped + // to 0. + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind)) + } + + a.releaseAddressStateLocked(addrState) +} + +// MainAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.GetKind() == Permanent + }) + if ep == nil { + return tcpip.AddressWithPrefix{} + } + + addr := ep.AddressWithPrefix() + a.decAddressRefLocked(ep) + return addr +} + +// acquirePrimaryAddressRLocked returns an acquired primary address that is +// valid according to isValid. +// +// Precondition: e.mu must be read locked +func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState { + var deprecatedEndpoint *addressState + for _, ep := range a.mu.primary { + if !isValid(ep) { + continue + } + + if !ep.Deprecated() { + if ep.IncRef() { + // ep is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + a.decAddressRefLocked(deprecatedEndpoint) + deprecatedEndpoint = nil + } + + return ep + } + } else if deprecatedEndpoint == nil && ep.IncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of + // ep in case a doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, ep's reference count + // will be decremented. + deprecatedEndpoint = ep + } + } + + return deprecatedEndpoint +} + +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + a.mu.Lock() + defer a.mu.Unlock() + + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } + + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } + + return addrState + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + if err != nil { + // addAndAcquireAddressLocked only returns an error if the address is + // already assigned but we just checked above if the address exists so we + // expect no error. + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + } + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + return ep +} + +// AcquireOutgoingPrimaryAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.IsAssigned(allowExpired) + }) + + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + + return ep +} + +// PrimaryAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.primary { + // Don't include tentative, expired or temporary endpoints + // to avoid confusion and prevent the caller from using + // those. + switch ep.GetKind() { + case PermanentTentative, PermanentExpired, Temporary: + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// PermanentAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.endpoints { + if !ep.GetKind().IsPermanent() { + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// JoinGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) + if err != nil { + return false, err + } + // We have no need for the address endpoint. + a.decAddressRefLocked(ep) + } + + a.mu.groups[group] = joins + 1 + return !ok, nil +} + +// LeaveGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + return false, tcpip.ErrBadLocalAddress + } + + if joins == 1 { + a.removeGroupAddressLocked(group) + delete(a.mu.groups, group) + return true, nil + } + + a.mu.groups[group] = joins - 1 + return false, nil +} + +// IsInGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { + a.mu.RLock() + defer a.mu.RUnlock() + _, ok := a.mu.groups[group] + return ok +} + +func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { + if err := a.removePermanentAddressLocked(group); err != nil { + // removePermanentEndpointLocked would only return an error if group is + // not bound to the addressable endpoint, but we know it MUST be assigned + // since we have group in our map of groups. + panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) + } +} + +// Cleanup forcefully leaves all groups and removes all permanent addresses. +func (a *AddressableEndpointState) Cleanup() { + a.mu.Lock() + defer a.mu.Unlock() + + for group := range a.mu.groups { + a.removeGroupAddressLocked(group) + } + a.mu.groups = make(map[tcpip.Address]uint32) + + for _, ep := range a.mu.endpoints { + // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // not a permanent address. + if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) + } + } +} + +var _ AddressEndpoint = (*addressState)(nil) + +// addressState holds state for an address. +type addressState struct { + addressableEndpointState *AddressableEndpointState + addr tcpip.AddressWithPrefix + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + refs uint32 + kind AddressKind + configType AddressConfigType + deprecated bool + } +} + +// NetworkEndpoint implements AddressEndpoint. +func (a *addressState) NetworkEndpoint() NetworkEndpoint { + return a.addressableEndpointState.networkEndpoint +} + +// AddressWithPrefix implements AddressEndpoint. +func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { + return a.addr +} + +// GetKind implements AddressEndpoint. +func (a *addressState) GetKind() AddressKind { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.kind +} + +// SetKind implements AddressEndpoint. +func (a *addressState) SetKind(kind AddressKind) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.kind = kind +} + +// IsAssigned implements AddressEndpoint. +func (a *addressState) IsAssigned(allowExpired bool) bool { + if !a.addressableEndpointState.networkEndpoint.Enabled() { + return false + } + + switch a.GetKind() { + case PermanentTentative: + return false + case PermanentExpired: + return allowExpired + default: + return true + } +} + +// IncRef implements AddressEndpoint. +func (a *addressState) IncRef() bool { + a.mu.Lock() + defer a.mu.Unlock() + if a.mu.refs == 0 { + return false + } + + a.mu.refs++ + return true +} + +// DecRef implements AddressEndpoint. +func (a *addressState) DecRef() { + a.addressableEndpointState.decAddressRef(a) +} + +// ConfigType implements AddressEndpoint. +func (a *addressState) ConfigType() AddressConfigType { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.configType +} + +// SetDeprecated implements AddressEndpoint. +func (a *addressState) SetDeprecated(d bool) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.deprecated = d +} + +// Deprecated implements AddressEndpoint. +func (a *addressState) Deprecated() bool { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.deprecated +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go new file mode 100644 index 000000000..26787d0a3 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// TestAddressableEndpointStateCleanup tests that cleaning up an addressable +// endpoint state removes permanent addresses and leaves groups. +func TestAddressableEndpointStateCleanup(t *testing.T) { + var ep fakeNetworkEndpoint + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + var s stack.AddressableEndpointState + s.Init(&ep) + + addr := tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + } + + { + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + } + // We don't need the address endpoint. + ep.DecRef() + } + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep == nil { + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) + } + ep.DecRef() + } + + group := tcpip.Address("\x02") + if added, err := s.JoinGroup(group); err != nil { + t.Fatalf("s.JoinGroup(%s): %s", group, err) + } else if !added { + t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) + } + if !s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) + } + + s.Cleanup() + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) + } + } + if s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + } +} diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 836682ea0..0cd1da11f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -196,13 +196,14 @@ type bucket struct { // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. +// +// Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := pkt.Network() + if netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol @@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { dstAddr: netHeader.DestinationAddress(), dstPort: tcpHeader.DestinationPort(), transProto: netHeader.TransportProtocol(), - netProto: header.IPv4ProtocolNumber, + netProto: pkt.NetworkProtocolNumber, }, nil } @@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -281,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.MinIP - replyTID.srcPort = rt.MinPort + replyTID.srcAddr = rt.Addr + replyTID.srcPort = rt.Port var manip manipType switch hook { case Prerouting: @@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction @@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { // support cases when they are validated, e.g. when we can't offload // receive checksumming. - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacketOutput manipulates ports for packets in Output hook. @@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction @@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacket will manipulate the port and address of the packet if the @@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return } @@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1 dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, transProto: header.TCPProtocolNumber, - netProto: header.IPv4ProtocolNumber, + netProto: netProto, } conn, _ := ct.connForTID(tid) if conn == nil { diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 54759091a..4e4b00a92 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,6 +46,8 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fwdTestNetworkEndpoint struct { + AddressableEndpointState + nicID tcpip.NICID proto *fwdTestNetworkProtocol dispatcher TransportDispatcher @@ -53,12 +56,18 @@ type fwdTestNetworkEndpoint struct { var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { + return nil +} + +func (*fwdTestNetworkEndpoint) Enabled() bool { + return true } -func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (*fwdTestNetworkEndpoint) Disable() {} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { @@ -78,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr return 0 } -func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -106,7 +111,9 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu return tcpip.ErrNotSupported } -func (*fwdTestNetworkEndpoint) Close() {} +func (f *fwdTestNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. @@ -116,6 +123,11 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) + + mu struct { + sync.RWMutex + forwarding bool + } } var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) @@ -145,13 +157,15 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { - return &fwdTestNetworkEndpoint{ - nicID: nicID, +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { + e := &fwdTestNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, + ep: nic.LinkEndpoint(), } + e.AddressableEndpointState.Init(e) + return e } func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { @@ -186,6 +200,21 @@ func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding + +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + // fwdTestPacketInfo holds all the information about an outbound packet. type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress @@ -307,7 +336,7 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocol{proto}, + NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, UseNeighborCache: useNeighborCache, }) @@ -316,7 +345,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 0e33cbe92..8d6d9a7f1 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -57,14 +57,14 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - tables: [numTables]Table{ + v4Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -83,9 +83,9 @@ func DefaultTables() *IPTables { }, mangleID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -101,10 +101,75 @@ func DefaultTables() *IPTables { }, filterID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + v6Tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -166,25 +231,20 @@ func EmptyNATTable() Table { // GetTable returns a table by name. func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - // TODO(gvisor.dev/issue/3549): Enable IPv6. - if ipv6 { - return Table{}, false - } id, ok := nameToID[name] if !ok { return Table{}, false } it.mu.RLock() defer it.mu.RUnlock() - return it.tables[id], true + if ipv6 { + return it.v6Tables[id], true + } + return it.v4Tables[id], true } // ReplaceTable replaces or inserts table by name. func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - // TODO(gvisor.dev/issue/3549): Enable IPv6. - if ipv6 { - return tcpip.ErrInvalidOptionValue - } id, ok := nameToID[name] if !ok { return tcpip.ErrInvalidOptionValue @@ -198,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Err it.startReaper(reaperDelay) } it.modified = true - it.tables[id] = table + if ipv6 { + it.v6Tables[id] = table + } else { + it.v4Tables[id] = table + } return nil } @@ -221,8 +285,15 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// which address and nicName can be gathered. Currently, address is only +// needed for prerouting and nicName is only needed for output. +// // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { + if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { + return true + } // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() @@ -243,9 +314,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr if tableID == natID && pkt.NatDone { continue } - table := it.tables[tableID] + var table Table + if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber { + table = it.v6Tables[tableID] + } else { + table = it.v4Tables[tableID] + } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -256,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -351,11 +427,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { case RuleAccept: return chainAccept @@ -372,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -398,11 +474,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { + if !rule.Filter.match(pkt, hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -421,16 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, tcpip.ErrNotConnected } - return it.connections.originalDst(epID) + return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 5f1b2af64..538c4625d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -21,78 +21,139 @@ import ( ) // AcceptTarget accepts packets. -type AcceptTarget struct{} +type AcceptTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (at *AcceptTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: at.NetworkProtocol, + } +} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } // DropTarget drops packets. -type DropTarget struct{} +type DropTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (dt *DropTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: dt.NetworkProtocol, + } +} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. -type ErrorTarget struct{} +type ErrorTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (et *ErrorTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: et.NetworkProtocol, + } +} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } // UserChainTarget marks a rule as the beginning of a user chain. type UserChainTarget struct { + // Name is the chain name. Name string + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (uc *UserChainTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: uc.NetworkProtocol, + } } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } // ReturnTarget returns from the current chain. If the chain is a built-in, the // hook's underflow should be called. -type ReturnTarget struct{} +type ReturnTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (rt *ReturnTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: rt.NetworkProtocol, + } +} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const RedirectTargetName = "REDIRECT" + // RedirectTarget redirects the packet by modifying the destination port/IP. -// Min and Max values for IP and Ports in the struct indicate the range of -// values which can be used to redirect. +// TODO(gvisor.dev/issue/170): Other flags need to be added after we support +// them. type RedirectTarget struct { - // TODO(gvisor.dev/issue/170): Other flags need to be added after - // we support them. - // RangeProtoSpecified flag indicates single port is specified to - // redirect. - RangeProtoSpecified bool + // Addr indicates address used to redirect. + Addr tcpip.Address - // MinIP indicates address used to redirect. - MinIP tcpip.Address + // Port indicates port used to redirect. + Port uint16 - // MaxIP indicates address used to redirect. - MaxIP tcpip.Address - - // MinPort indicates port used to redirect. - MinPort uint16 + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} - // MaxPort indicates port used to redirect. - MaxPort uint16 +// ID implements Target.ID. +func (rt *RedirectTarget) ID() TargetID { + return TargetID{ + Name: RedirectTargetName, + NetworkProtocol: rt.NetworkProtocol, + } } // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -103,34 +164,35 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1) in Output and - // to primary address of the incoming interface in Prerouting. + // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // the primary address of the incoming interface in Prerouting. switch hook { case Output: - rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) - rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + } else { + rt.Addr = header.IPv6Loopback + } case Prerouting: - rt.MinIP = address - rt.MaxIP = address + rt.Addr = address default: panic("redirect target is supported only on output and prerouting hooks") } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - switch protocol := netHeader.TransportProtocol(); protocol { + switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.MinPort) + udpHeader.SetDestinationPort(rt.Port) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(protocol, length) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) @@ -139,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - // Change destination address. - netHeader.SetDestinationAddress(rt.MinIP) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + + pkt.Network().SetDestinationAddress(rt.Addr) + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fbbd2f50f..7b3f3e88b 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "strings" "sync" @@ -81,31 +82,42 @@ const ( // // +stateify savable type IPTables struct { - // mu protects tables, priorities, and modified. + // mu protects v4Tables, v6Tables, and modified. mu sync.RWMutex - - // tables maps tableIDs to tables. Holds builtin tables only, not user - // tables. mu must be locked for accessing. - tables [numTables]Table - - // priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. mu needs to be locked for accessing. - priorities [NumHooks][]tableID - + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. mu must be locked for accessing. + v4Tables [numTables]Table + v6Tables [numTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. It is immutable. + priorities [NumHooks][]tableID + connections ConnTrack - // reaperDone can be signalled to stop the reaper goroutine. + // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} } -// A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules. +// A Table defines a set of chains and hooks into the network stack. +// +// It is a list of Rules, entry points (BuiltinChains), and error handlers +// (Underflows). As packets traverse netstack, they hit hooks. When a packet +// hits a hook, iptables compares it to Rules starting from that hook's entry +// point. So if a packet hits the Input hook, we look up the corresponding +// entry point in BuiltinChains and jump to that point. +// +// If the Rule doesn't match the packet, iptables continues to the next Rule. +// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept +// or RuleDrop) that causes the packet to stop traversing iptables. It can also +// jump to other rules or perform custom actions based on Rule.Target. +// +// Underflow Rules are invoked when a chain returns without reaching a verdict. // // +stateify savable type Table struct { @@ -148,7 +160,7 @@ type Rule struct { Target Target } -// IPHeaderFilter holds basic IP filtering data common to every rule. +// IPHeaderFilter performs basic IP header matching common to every rule. // // +stateify savable type IPHeaderFilter struct { @@ -196,16 +208,43 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } -// match returns whether hdr matches the filter. -func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. +// match returns whether pkt matches the filter. +// +// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 +// or IPv6 header length. +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { + // Extract header fields. + var ( + // TODO(gvisor.dev/issue/170): Support other filter fields. + transProto tcpip.TransportProtocolNumber + dstAddr tcpip.Address + srcAddr tcpip.Address + ) + switch proto := pkt.NetworkProtocolNumber; proto { + case header.IPv4ProtocolNumber: + hdr := header.IPv4(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + case header.IPv6ProtocolNumber: + hdr := header.IPv6(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + default: + panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto)) + } + // Check the transport protocol. - if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + if fl.CheckProtocol && fl.Protocol != transProto { return false } - // Check the source and destination IPs. - if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + // Check the addresses. + if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) || + !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) { return false } @@ -233,6 +272,18 @@ func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool return true } +// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header +// applies. +func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber { + switch len(fl.Src) { + case header.IPv4AddressSize: + return header.IPv4ProtocolNumber + case header.IPv6AddressSize: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src)) +} + // filterAddress returns whether addr matches the filter. func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { matches := true @@ -258,8 +309,23 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } +// A TargetID uniquely identifies a target. +type TargetID struct { + // Name is the target name as stored in the xt_entry_target struct. + Name string + + // NetworkProtocol is the protocol to which the target applies. + NetworkProtocol tcpip.NetworkProtocolNumber + + // Revision is the version of the target. + Revision uint8 +} + // A Target is the interface for taking an action for a packet. type Target interface { + // ID uniquely identifies the Target. + ID() TargetID + // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 14fb4239b..33806340e 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math" "sync/atomic" "testing" "time" @@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) { } func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) + // There is a race condition causing this test to fail when the executor + // takes longer than the resolution timeout to call linkAddrCache.get. This + // is especially common when this test is run with gotsan. + // + // Using a large resolution timeout decreases the probability of experiencing + // this race condition and does not affect how long this test takes to run. + c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 67dc5364f..73a01c2dd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -150,10 +150,10 @@ type ndpDNSSLEvent struct { type ndpDHCPv6Event struct { nicID tcpip.NICID - configuration stack.DHCPv6ConfigurationFromNDPRA + configuration ipv6.DHCPv6ConfigurationFromNDPRA } -var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. @@ -170,7 +170,7 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ @@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } } -// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip. return n.rememberRouter } -// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip } } -// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip return n.rememberPrefix } -// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi } } -// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { if c := n.rdnssC; c != nil { c <- ndpRDNSSEvent{ @@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } -// Implements stack.NDPDispatcher.OnDNSSearchListOption. +// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { if n.dnsslC != nil { n.dnsslC <- ndpDNSSLEvent{ @@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s } } -// Implements stack.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { +// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { c <- ndpDHCPv6Event{ nicID, @@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = test.retransTimer - opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits e := channelLinkWithHeaderLength{ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ipv6.NDPConfigurations{ + RetransmitTimer: test.retransTimer, + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + }, + })}, + }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -558,6 +559,26 @@ func TestDADResolve(t *testing.T) { } } +func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(tgt) + snmc := header.SolicitedNodeAddr(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) +} + // TestDADFail tests to make sure that the DAD process fails if another node is // detected to be performing DAD on the same address (receive an NS message from // a node doing DAD for the same address), or if another node is detected to own @@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) { tests := []struct { name string - makeBuf func(tgt tcpip.Address) buffer.Prependable + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "RxSolicit", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RxSolicit", + rxPkt: rxNDPSolicit, + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, }, { - "RxAdvert", - func(tgt tcpip.Address) buffer.Prependable { + name: "RxAdvert", + rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) @@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) { SrcAddr: tgt, DstAddr: header.IPv6AllNodesMulticastAddress, }) - - return hdr - + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, }, @@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = time.Second * 2 + ndpConfigs := ipv6.DefaultNDPConfigurations() + ndpConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -664,13 +663,8 @@ func TestDADFail(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Receive a packet to simulate multiple nodes owning or - // attempting to own the same address. - hdr := test.makeBuf(addr1) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(header.IPv6ProtocolNumber, pkt) + // Receive a packet to simulate an address conflict. + test.rxPkt(e, addr1) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -754,18 +748,19 @@ func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.NDPConfigurations{ + + ndpConfigs := ipv6.NDPConfigurations{ RetransmitTimer: time.Second, DupAddrDetectTransmits: 2, } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -815,19 +810,6 @@ func TestDADStop(t *testing.T) { } } -// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NDP configurations using an invalid NICID. -func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - - // No NIC with ID 1 yet. - if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) - } -} - // TestSetNDPConfigurations tests that we can update and use per-interface NDP // configurations without affecting the default NDP configurations or other // interfaces' configurations. @@ -863,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -892,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) { } // Update the NDP configurations on NIC(1) to use DAD. - configs := stack.NDPConfigurations{ + configs := ipv6.NDPConfigurations{ DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(nicID1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(configs) } // Created after updating NIC(1)'s NDP configurations @@ -1113,14 +1099,15 @@ func TestNoRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1151,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1192,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func(addr tcpip.Address, discovered bool) { @@ -1285,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) { } // TestRouterDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), @@ -1293,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1306,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - if i <= stack.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredDefaultRouters { select { case e := <-ndpDisp.routerC: if diff := checkRouterEvent(e, llAddr, true); diff != "" { @@ -1358,14 +1348,15 @@ func TestNoPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1399,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1445,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1545,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1621,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // TestPrefixDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) } - optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link // prefixes. - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} prefixAddr[7] = byte(i) prefix := tcpip.AddressWithPrefix{ @@ -1665,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < stack.MaxDiscoveredOnLinkPrefixes { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < ipv6.MaxDiscoveredOnLinkPrefixes { select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { @@ -1716,14 +1711,15 @@ func TestNoAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1749,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr(t *testing.T) { +func TestAutoGenAddr2(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1766,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1876,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := stack.MaxDesyncFactor + savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MaxDesyncFactor = time.Nanosecond prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1931,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2119,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond tests := []struct { name string @@ -2160,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2211,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = 0 + ipv6.MaxDesyncFactor = 0 prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2228,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2294,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2317,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2382,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Wait for all the temporary addresses to get invalidated. @@ -2439,17 +2443,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2462,16 +2466,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2545,9 +2550,12 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // as paased. ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) select { case e := <-ndpDisp.autoGenAddrC: @@ -2565,9 +2573,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) - } + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } @@ -2655,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) s.SetRouteTable([]tcpip.Route{{ @@ -2739,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { ndpDisp.dadC = make(chan ndpDADEvent, 2) ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits ndpConfigs.RetransmitTimer = retransmitTimer - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Do SLAAC for prefix. @@ -2754,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // DAD failure to restart the local generation process. addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] expectAutoGenAddrAsyncEvent(addr, newAddr) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { @@ -2794,14 +2802,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, - UseNeighborCache: useNeighborCache, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3036,11 +3045,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -3258,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 const minVLSeconds = 1 savedIL := header.NDPInfiniteLifetime - savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL header.NDPInfiniteLifetime = savedIL }() - stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3307,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3357,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 const newMinVL = 4 - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3449,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { } e := channel.New(10, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3515,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3700,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3781,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3856,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const lifetimeSeconds = 10 // Needed for the temporary address sub test. - savedMaxDesync := stack.MaxDesyncFactor + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] @@ -3938,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -3963,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, }, @@ -3977,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "Temporary address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -4029,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, }, - SecretKey: secretKey, - }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4059,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) expectDADEvent(t, &ndpDisp, addr.Address, false) @@ -4119,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool subnet tcpip.Subnet triggerSLAACFn func(e *channel.Endpoint) }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -4142,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, AutoGenAddressConflictRetries: maxRetries, @@ -4165,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4198,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4250,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4296,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Simulate a DAD conflict after some time has passed. time.Sleep(failureTimer) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4459,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4509,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4640,7 +4653,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -4694,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func() (bool, ndpRouterEvent) { @@ -4967,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { t.Helper() select { case e := <-ndpDisp.dhcpv6ConfigurationC: @@ -5002,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Even if the first RA reports no DHCPv6 configurations are available, the // dispatcher should get an event. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) // Receiving the same update again should not result in an event to the // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) @@ -5011,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to Managed Address. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to none. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -5031,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // // Note, when the M flag is set, the O flag is redundant. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) expectNoDHCPv6Event() // Even though the DHCPv6 flags are different, the effective configuration is @@ -5044,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -5059,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() } @@ -5217,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) { } } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5286,11 +5302,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, @@ -5357,12 +5373,13 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.NDPRS()) } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b4fa69e3e..a0b7da5cd 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) const ( @@ -239,7 +240,7 @@ type entryEvent struct { func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) if got, want := neigh.config(), c; got != want { @@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) { func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) c.MinRandomFactor = 1 @@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) { func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -339,7 +340,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -358,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -409,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } type testContext struct { - clock *fakeClock + clock *faketime.ManualClock neigh *neighborCache store *testEntryStore linkRes *testNeighborResolver @@ -418,7 +419,7 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -454,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -567,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -803,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -876,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -902,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if doneCh == nil { t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -944,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -974,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { // Remove the waker before the neighbor cache has the opportunity to send a // notification. neigh.removeWaker(entry.Addr, &w) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -1073,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1092,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -1188,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1249,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { config.MaxRandomFactor = 1 nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1277,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1325,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1412,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1440,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Wait() // Process all the requests for a single entry concurrently - clock.advance(typicalLatency) + clock.Advance(typicalLatency) } // All goroutines add in the same order and add more values than can fit in @@ -1472,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1491,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1541,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(config.DelayFirstProbeTime + typicalLatency) + clock.Advance(config.DelayFirstProbeTime + typicalLatency) select { case <-doneCh: default: @@ -1552,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) { // Verify the entry's new link address { e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } @@ -1572,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() @@ -1595,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) if err != nil { t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) @@ -1618,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1636,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config := DefaultNUDConfigurations() config.RetransmitTimer = time.Millisecond // small enough to cause timeout - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1654,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1664,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { // resolved immediately and don't send resolution requests. func TestNeighborCacheStaticResolution(t *testing.T) { config := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 0068cacb8..9a72bec79 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // NeighborEntry describes a neighboring device in the local network. @@ -73,8 +74,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - protocol tcpip.NetworkProtocolNumber + nic *NIC // linkRes provides the functionality to send reachability probes, used in // Neighbor Unreachability Detection. @@ -440,7 +440,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.notifyWakersLocked() } - if e.isRouter && !flags.IsRouter { + if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { // "In those cases where the IsRouter flag changes from TRUE to FALSE as // a result of this update, the node MUST remove that router from the // Default Router List and update the Destination Cache entries for all @@ -448,9 +448,17 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // 7.3.3. This is needed to detect when a node that is used as a router // stops forwarding packets due to being configured as a host." // - RFC 4861 section 7.2.5 - e.nic.mu.Lock() - e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) - e.nic.mu.Unlock() + // + // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 + // here. + ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + if !ok { + panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) + } + + if ndpEP, ok := ep.(NDPEndpoint); ok { + ndpEP.InvalidateDefaultRouter(e.neigh.Addr) + } } e.isRouter = flags.IsRouter diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index b769fb2fa..a265fff0a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -27,6 +27,8 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -221,8 +223,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return entryTestNetNumber } -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { - clock := newFakeClock() +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { + clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ id: entryTestNICID, @@ -232,18 +234,15 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudDisp: &disp, }, } + nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ + header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) - // Stub out ndpState to verify modification of default routers. - nic.mu.ndp = ndpState{ - nic: &nic, - defaultRouters: make(map[tcpip.Address]defaultRouterState), - } - // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ nic: &nic, @@ -267,7 +266,7 @@ func TestEntryInitiallyUnknown(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // No probes should have been sent. linkRes.mu.Lock() @@ -300,7 +299,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { } e.mu.Unlock() - clock.advance(time.Hour) + clock.Advance(time.Hour) // No probes should have been sent. linkRes.mu.Lock() @@ -410,7 +409,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { updatedAt := e.neigh.UpdatedAt e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should remain the same during address resolution. wantProbes := []entryTestProbeInfo{ @@ -439,7 +438,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should change after failing address resolution. Timing out after // sending the last probe transitions the entry to Failed. @@ -459,7 +458,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } } - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) wantEvents := []testEntryEventInfo{ { @@ -748,7 +747,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The Incomplete-to-Incomplete state transition is tested here by @@ -816,6 +815,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) + ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + e.mu.Lock() e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ @@ -829,9 +830,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) } - e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ - invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), - } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -840,8 +839,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { if got, want := e.isRouter, false; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) } - if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { - t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + if ipv6EP.invalidatedRtr != e.neigh.Addr { + t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr) } e.mu.Unlock() @@ -983,7 +982,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1612,7 +1611,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1706,7 +1705,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1989,7 +1988,7 @@ func TestEntryDelayToProbe(t *testing.T) { } e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2069,7 +2068,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2166,7 +2165,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2267,7 +2266,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2364,7 +2363,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // Probe caused by the Delay-to-Probe transition @@ -2398,7 +2397,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2463,7 +2462,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2503,7 +2502,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2575,7 +2574,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2612,7 +2611,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2682,7 +2681,7 @@ func TestEntryProbeToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2787,7 +2786,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 863ef6bee..06824843a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -18,7 +18,6 @@ import ( "fmt" "math/rand" "reflect" - "sort" "sync/atomic" "gvisor.dev/gvisor/pkg/sleep" @@ -28,13 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -var ipv4BroadcastAddr = tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 8 * header.IPv4AddressSize, - }, -} +var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. @@ -49,18 +42,18 @@ type NIC struct { neigh *neighborCache networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + mu struct { sync.RWMutex - enabled bool spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - ndp ndpState } } @@ -84,25 +77,6 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } -// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior. -type PrimaryEndpointBehavior int - -const ( - // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. - CanBePrimaryEndpoint PrimaryEndpointBehavior = iota - - // FirstPrimaryEndpoint indicates the endpoint should be the first - // primary endpoint considered. If there are multiple endpoints with - // this behavior, the most recently-added one will be first. - FirstPrimaryEndpoint - - // NeverPrimaryEndpoint indicates the endpoint should never be a - // primary endpoint. - NeverPrimaryEndpoint -) - // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -122,19 +96,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) - nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) - nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - nic.mu.ndp.initializeTempAddrState() // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -162,7 +124,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = nil - nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nud, nic, ep, stack) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } nic.linkEP.Attach(nic) @@ -170,29 +132,28 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -// enabled returns true if n is enabled. -func (n *NIC) enabled() bool { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - return enabled +// Enabled implements NetworkInterface. +func (n *NIC) Enabled() bool { + return atomic.LoadUint32(&n.enabled) == 1 } -// disable disables n. +// setEnabled sets the enabled status for the NIC. // -// It undoes the work done by enable. -func (n *NIC) disable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if !enabled { - return nil +// Returns true if the enabled status was updated. +func (n *NIC) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&n.enabled, 1) == 0 } + return atomic.SwapUint32(&n.enabled, 0) == 1 +} +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() { n.mu.Lock() - err := n.disableLocked() + n.disableLocked() n.mu.Unlock() - return err } // disableLocked disables n. @@ -200,9 +161,9 @@ func (n *NIC) disable() *tcpip.Error { // It undoes the work done by enable. // // n MUST be locked. -func (n *NIC) disableLocked() *tcpip.Error { - if !n.mu.enabled { - return nil +func (n *NIC) disableLocked() { + if !n.setEnabled(false) { + return } // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be @@ -210,38 +171,9 @@ func (n *NIC) disableLocked() *tcpip.Error { // again, and applications may not know that the underlying NIC was ever // disabled. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - n.mu.ndp.stopSolicitingRouters() - n.mu.ndp.cleanupState(false /* hostOnly */) - - // Stop DAD for all the unicast IPv6 endpoints that are in the - // permanentTentative state. - for _, r := range n.mu.endpoints { - if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } - } - - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - } - - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - - // The address may have already been removed. - if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } + for _, ep := range n.networkEndpoints { + ep.Disable() } - - n.mu.enabled = false - return nil } // enable enables n. @@ -251,162 +183,39 @@ func (n *NIC) disableLocked() *tcpip.Error { // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if enabled { - return nil - } - n.mu.Lock() defer n.mu.Unlock() - if n.mu.enabled { - return nil - } - - n.mu.enabled = true - - // Create an endpoint to receive broadcast packets on this interface. - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } - - // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts - // multicast group. Note, the IANA calls the all-hosts multicast group the - // all-systems multicast group. - if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { - return err - } - } - - // Join the IPv6 All-Nodes Multicast group if the stack is configured to - // use IPv6. This is required to ensure that this node properly receives - // and responds to the various NDP messages that are destined to the - // all-nodes multicast address. An example is the Neighbor Advertisement - // when we perform Duplicate Address Detection, or Router Advertisement - // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 - // section 4.2 for more information. - // - // Also auto-generate an IPv6 link-local address based on the NIC's - // link address if it is configured to do so. Note, each interface is - // required to have IPv6 link-local unicast address, as per RFC 4291 - // section 2.1. - _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] - if !ok { + if !n.setEnabled(true) { return nil } - // Join the All-Nodes multicast group before starting DAD as responses to DAD - // (NDP NS) messages may be sent to the All-Nodes multicast group if the - // source address of the NDP NS is the unspecified address, as per RFC 4861 - // section 7.2.4. - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - - // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent - // state. - // - // Addresses may have aleady completed DAD but in the time since the NIC was - // last enabled, other devices may have acquired the same addresses. - for _, r := range n.mu.endpoints { - addr := r.address() - if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { - continue - } - - r.setKind(permanentTentative) - if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + for _, ep := range n.networkEndpoints { + if err := ep.Enable(); err != nil { return err } } - // Do not auto-generate an IPv6 link-local address for loopback devices. - if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { - // The valid and preferred lifetime is infinite for the auto-generated - // link-local address. - n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) - } - - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyways. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !n.stack.forwarding { - n.mu.ndp.startSolicitingRouters() - } - return nil } -// remove detaches NIC from the link endpoint, and marks existing referenced -// network endpoints expired. This guarantees no packets between this NIC and -// the network stack. +// remove detaches NIC from the link endpoint and releases network endpoint +// resources. This guarantees no packets between this NIC and the network +// stack. func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() n.disableLocked() - // TODO(b/151378115): come up with a better way to pick an error than the - // first one. - var err *tcpip.Error - - // Forcefully leave multicast groups. - for nid := range n.mu.mcastJoins { - if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { - err = tempErr - } - } - - // Remove permanent and permanentTentative addresses, so no packet goes out. - for nid, ref := range n.mu.endpoints { - switch ref.getKind() { - case permanentTentative, permanent: - if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { - err = tempErr - } - } - } - - // Release any resources the network endpoint may hold. for _, ep := range n.networkEndpoints { ep.Close() } + n.networkEndpoints = nil // Detach from link endpoint, so no packet comes in. n.linkEP.Attach(nil) - - return err -} - -// becomeIPv6Router transitions n into an IPv6 router. -// -// When transitioning into an IPv6 router, host-only state (NDP discovered -// routers, discovered on-link prefixes, and auto-generated addresses) will -// be cleaned up/invalidated and NDP router solicitations will be stopped. -func (n *NIC) becomeIPv6Router() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.cleanupState(true /* hostOnly */) - n.mu.ndp.stopSolicitingRouters() -} - -// becomeIPv6Host transitions n into an IPv6 host. -// -// When transitioning into an IPv6 host, NDP router solicitations will be -// started. -func (n *NIC) becomeIPv6Host() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.startSolicitingRouters() + return nil } // setPromiscuousMode enables or disables promiscuous mode. @@ -423,7 +232,8 @@ func (n *NIC) isPromiscuousMode() bool { return rv } -func (n *NIC) isLoopback() bool { +// IsLoopback implements NetworkInterface. +func (n *NIC) IsLoopback() bool { return n.linkEP.Capabilities()&CapabilityLoopback != 0 } @@ -434,206 +244,44 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists for the given protocol and remoteAddr. If no non-deprecated -// endpoint exists, the first deprecated endpoint will be returned. -// -// If an IPv6 primary endpoint is requested, Source Address Selection (as -// defined by RFC 6724 section 5) will be performed. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { - if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 { - return n.primaryIPv6Endpoint(remoteAddr) - } - - n.mu.RLock() - defer n.mu.RUnlock() - - var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.mu.primary[protocol] { - if !r.isValidForOutgoingRLocked() { - continue - } - - if !r.deprecated { - if r.tryIncRef() { - // r is not deprecated, so return it immediately. - // - // If we kept track of a deprecated endpoint, decrement its reference - // count since it was incremented when we decided to keep track of it. - if deprecatedEndpoint != nil { - deprecatedEndpoint.decRefLocked() - deprecatedEndpoint = nil - } - - return r - } - } else if deprecatedEndpoint == nil && r.tryIncRef() { - // We prefer an endpoint that is not deprecated, but we keep track of r in - // case n doesn't have any non-deprecated endpoints. - // - // If we end up finding a more preferred endpoint, r's reference count - // will be decremented when such an endpoint is found. - deprecatedEndpoint = r - } - } - - // n doesn't have any valid non-deprecated endpoints, so return - // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated - // endpoints either). - return deprecatedEndpoint -} - -// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC -// 6724 section 5). -type ipv6AddrCandidate struct { - ref *referencedNetworkEndpoint - scope header.IPv6AddressScope -} - -// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { +// primaryAddress returns an address that can be used to communicate with +// remoteAddr. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { n.mu.RLock() - ref := n.primaryIPv6EndpointRLocked(remoteAddr) + spoofing := n.mu.spoofing n.mu.RUnlock() - return ref -} - -// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -// -// n.mu MUST be read locked. -func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] - - if len(primaryAddrs) == 0 { - return nil - } - - // Create a candidate set of available addresses we can potentially use as a - // source address. - cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) - for _, r := range primaryAddrs { - // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoingRLocked() { - continue - } - - addr := r.address() - scope, err := header.ScopeForIPv6Address(addr) - if err != nil { - // Should never happen as we got r from the primary IPv6 endpoint list and - // ScopeForIPv6Address only returns an error if addr is not an IPv6 - // address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) - } - - cs = append(cs, ipv6AddrCandidate{ - ref: r, - scope: scope, - }) - } - - remoteScope, err := header.ScopeForIPv6Address(remoteAddr) - if err != nil { - // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) - } - - // Sort the addresses as per RFC 6724 section 5 rules 1-3. - // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. - sort.Slice(cs, func(i, j int) bool { - sa := cs[i] - sb := cs[j] - - // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.ref.address() == remoteAddr { - return true - } - if sb.ref.address() == remoteAddr { - return false - } - - // Prefer appropriate scope as per RFC 6724 section 5 rule 2. - if sa.scope < sb.scope { - return sa.scope >= remoteScope - } else if sb.scope < sa.scope { - return sb.scope < remoteScope - } - - // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. - if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { - // If sa is not deprecated, it is preferred over sb. - return sbDep - } - - // Prefer temporary addresses as per RFC 6724 section 5 rule 7. - if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { - return saTemp - } - - // sa and sb are equal, return the endpoint that is closest to the front of - // the primary endpoint list. - return i < j - }) - - // Return the most preferred address that can have its reference count - // incremented. - for _, c := range cs { - if r := c.ref; r.tryIncRef() { - return r - } - } - - return nil -} - -// hasPermanentAddrLocked returns true if n has a permanent (including currently -// tentative) address, addr. -func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] + ep, ok := n.networkEndpoints[protocol] if !ok { - return false + return nil } - kind := ref.getKind() - - return kind == permanent || kind == permanentTentative + return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } -type getRefBehaviour int +type getAddressBehaviour int const ( // spoofing indicates that the NIC's spoofing flag should be observed when - // getting a NIC's referenced network endpoint. - spoofing getRefBehaviour = iota + // getting a NIC's address endpoint. + spoofing getAddressBehaviour = iota // promiscuous indicates that the NIC's promiscuous flag should be observed - // when getting a NIC's referenced network endpoint. + // when getting a NIC's address endpoint. promiscuous ) -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) +func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, spoofing) +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) } -// getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. +// getAddressEpOrCreateTemp returns the address endpoint for the given protocol +// and address. // // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. @@ -641,9 +289,8 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // // If the address is the IPv4 broadcast address for an endpoint's network, that // endpoint will be returned. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { +func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { n.mu.RLock() - var spoofingOrPromiscuous bool switch tempRef { case spoofing: @@ -651,294 +298,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous } - - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // An endpoint with this id exists, check if it can be used and return it. - if !ref.isAssignedRLocked(spoofingOrPromiscuous) { - n.mu.RUnlock() - return nil - } - - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - } - - // Check if address is a broadcast address for the endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. - if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { - n.mu.RUnlock() - return ref - } - } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the IPv4 address is found in the NIC's subnets and the NIC - // is a loopback interface. - createTempEP := spoofingOrPromiscuous - if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber { - for _, r := range n.mu.endpoints { - addr := r.addrWithPrefix() - subnet := addr.Subnet() - if subnet.Contains(address) { - createTempEP = true - break - } - } - } n.mu.RUnlock() - - if !createTempEP { - return nil - } - - // Try again with the lock in exclusive mode. If we still can't get the - // endpoint, create a new "temporary" endpoint. It will only exist while - // there's a route through it. - n.mu.Lock() - ref := n.getRefOrCreateTempLocked(protocol, address, peb) - n.mu.Unlock() - return ref + return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb) } -// getRefForBroadcastLocked returns an endpoint where address is the IPv4 -// broadcast address for the endpoint's network. -// -// n.mu MUST be read locked. -func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { - for _, ref := range n.mu.endpoints { - // Only IPv4 has a notion of broadcast addresses. - if ref.protocol != header.IPv4ProtocolNumber { - continue - } - - addr := ref.addrWithPrefix() - subnet := addr.Subnet() - if subnet.IsBroadcast(address) && ref.tryIncRef() { - return ref - } +// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean +// is passed to indicate whether or not we should generate temporary endpoints. +func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + if ep, ok := n.networkEndpoints[protocol]; ok { + return ep.AcquireAssignedAddress(address, createTemp, peb) } return nil } -/// getRefOrCreateTempLocked returns an existing endpoint for address or creates -/// and returns a temporary endpoint. -// -// If the address is the IPv4 broadcast address for an endpoint's network, that -// endpoint will be returned. -// -// n.mu must be write locked. -func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // No need to check the type as we are ok with expired endpoints at this - // point. - if ref.tryIncRef() { - return ref - } - // tryIncRef failing means the endpoint is scheduled to be removed once the - // lock is released. Remove it here so we can create a new (temporary) one. - // The removal logic waiting for the lock handles this case. - n.removeEndpointLocked(ref) - } - - // Check if address is a broadcast address for an endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. - if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { - return ref - } - } - - // Add a new temporary endpoint. - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return nil - } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, temporary, static, false) - return ref -} - -// addAddressLocked adds a new protocolAddress to n. -// -// If n already has the address in a non-permanent state, and the kind given is -// permanent, that address will be promoted in place and its properties set to -// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP addresses before adding them. - - // Sanity check. - id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.mu.endpoints[id]; ok { - // Endpoint already exists. - if kind != permanent { - return nil, tcpip.ErrDuplicateAddress - } - switch ref.getKind() { - case permanentTentative, permanent: - // The NIC already have a permanent endpoint with that address. - return nil, tcpip.ErrDuplicateAddress - case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect the new peb, - // configType and deprecated status. - if ref.tryIncRef() { - // TODO(b/147748385): Perform Duplicate Address Detection when promoting - // an IPv6 endpoint to permanent. - ref.setKind(permanent) - ref.deprecated = deprecated - ref.configType = configType - - refs := n.mu.primary[ref.protocol] - for i, r := range refs { - if r == ref { - switch peb { - case CanBePrimaryEndpoint: - return ref, nil - case FirstPrimaryEndpoint: - if i == 0 { - return ref, nil - } - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - case NeverPrimaryEndpoint: - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - return ref, nil - } - } - } - - n.insertPrimaryEndpointLocked(ref, peb) - - return ref, nil - } - // tryIncRef failing means the endpoint is scheduled to be removed once - // the lock is released. Remove it here so we can create a new - // (permanent) one. The removal logic waiting for the lock handles this - // case. - n.removeEndpointLocked(ref) - } - } - +// addAddress adds a new address to n, so that it starts accepting packets +// targeted at the given address (and network protocol). +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return nil, tcpip.ErrUnknownProtocol - } - - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) - - // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process if the NIC is - // enabled. If the NIC is not enabled, DAD will be started when the NIC is - // enabled. - if isIPv6Unicast && kind == permanent { - kind = permanentTentative - } - - ref := &referencedNetworkEndpoint{ - refs: 1, - addr: protocolAddress.AddressWithPrefix, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, - configType: configType, - deprecated: deprecated, - } - - // Set up resolver if link address resolution exists for this protocol. - if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { - ref.linkCache = n.stack - ref.linkRes = linkRes - } - } - - // If we are adding an IPv6 unicast address, join the solicited-node - // multicast address. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) - if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { - return nil, err - } + return tcpip.ErrUnknownProtocol } - n.mu.endpoints[id] = ref - - n.insertPrimaryEndpointLocked(ref, peb) - - // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. - if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { - if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { - return nil, err - } + addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + if err == nil { + // We have no need for the address endpoint. + addressEndpoint.DecRef() } - - return ref, nil -} - -// AddAddress adds a new address to n, so that it starts accepting packets -// targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { - // Add the endpoint. - n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) - n.mu.Unlock() - return err } -// AllAddresses returns all addresses (primary and non-primary) associated with +// allPermanentAddresses returns all permanent addresses associated with // this NIC. -func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - - addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) - for _, ref := range n.mu.endpoints { - // Don't include tentative, expired or temporary endpoints to - // avoid confusion and prevent the caller from using those. - switch ref.getKind() { - case permanentExpired, temporary: - continue +func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { + var addrs []tcpip.ProtocolAddress + for p, ep := range n.networkEndpoints { + for _, a := range ep.PermanentAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: ref.addrWithPrefix(), - }) } return addrs } -// PrimaryAddresses returns the primary addresses associated with this NIC. -func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - +// primaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress - for proto, list := range n.mu.primary { - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints - // to avoid confusion and prevent the caller from using - // those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: ref.addrWithPrefix(), - }) + for p, ep := range n.networkEndpoints { + for _, a := range ep.PrimaryAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } return addrs @@ -950,147 +357,25 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { - n.mu.RLock() - defer n.mu.RUnlock() - - list, ok := n.mu.primary[proto] + ep, ok := n.networkEndpoints[proto] if !ok { return tcpip.AddressWithPrefix{} } - var deprecatedEndpoint *referencedNetworkEndpoint - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints to avoid confusion - // and prevent the caller from using those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - if !ref.deprecated { - return ref.addrWithPrefix() - } - - if deprecatedEndpoint == nil { - deprecatedEndpoint = ref - } - } - - if deprecatedEndpoint != nil { - return deprecatedEndpoint.addrWithPrefix() - } - - return tcpip.AddressWithPrefix{} -} - -// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required -// by peb. -// -// n MUST be locked. -func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { - switch peb { - case CanBePrimaryEndpoint: - n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) - case FirstPrimaryEndpoint: - n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) - } -} - -func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := NetworkEndpointID{LocalAddress: r.address()} - - // Nothing to do if the reference has already been replaced with a different - // one. This happens in the case where 1) this endpoint's ref count hit zero - // and was waiting (on the lock) to be removed and 2) the same address was - // re-added in the meantime by removing this endpoint from the list and - // adding a new one. - if n.mu.endpoints[id] != r { - return - } - - if r.getKind() == permanent { - panic("Reference count dropped to zero before being removed") - } - - delete(n.mu.endpoints, id) - refs := n.mu.primary[r.protocol] - for i, ref := range refs { - if ref == r { - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - refs[len(refs)-1] = nil - break - } - } -} - -func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { - n.mu.Lock() - n.removeEndpointLocked(r) - n.mu.Unlock() -} - -func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadLocalAddress - } - - kind := r.getKind() - if kind != permanent && kind != permanentTentative { - return tcpip.ErrBadLocalAddress - } - - switch r.protocol { - case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) - default: - r.expireLocked() - return nil - } + return ep.MainAddress() } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { - addr := r.addrWithPrefix() - - isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) - - if isIPv6Unicast { - n.mu.ndp.stopDuplicateAddressDetection(addr.Address) - - // If we are removing an address generated via SLAAC, cleanup - // its SLAAC resources and notify the integrator. - switch r.configType { - case slaac: - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - case slaacTemp: - n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - } - } - - r.expireLocked() - - // At this point the endpoint is deleted. - - // If we are removing an IPv6 unicast address, leave the solicited-node - // multicast address. - // - // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node - // multicast group may be left by user action. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr.Address) - if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { +// removeAddress removes an address from n. +func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { + for _, ep := range n.networkEndpoints { + if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + continue + } else { return err } } - return nil -} - -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - return n.removePermanentAddressLocked(addr) + return tcpip.ErrBadLocalAddress } func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { @@ -1141,91 +426,66 @@ func (n *NIC) clearNeighbors() *tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.joinGroupLocked(protocol, addr) -} - -// joinGroupLocked adds a new endpoint for the given multicast address, if none -// exists yet. Otherwise it just increments its count. n MUST be locked before -// joinGroupLocked is called. -func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as // outlined in RFC 3810 section 5. - id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - if joins == 0 { - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return tcpip.ErrUnknownProtocol - } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported } - n.mu.mcastJoins[id] = joins + 1 - return nil + + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + _, err := gep.JoinGroup(addr) + return err } // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.leaveGroupLocked(addr, false /* force */) -} +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported + } -// leaveGroupLocked decrements the count for the given multicast address, and -// when it reaches zero removes the endpoint for this address. n MUST be locked -// before leaveGroupLocked is called. -// -// If force is true, then the count for the multicast addres is ignored and the -// endpoint will be removed immediately. -func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { - id := NetworkEndpointID{addr} - joins, ok := n.mu.mcastJoins[id] + gep, ok := ep.(GroupAddressableEndpoint) if !ok { - // There are no joins with this address on this NIC. - return tcpip.ErrBadLocalAddress + return tcpip.ErrNotSupported } - joins-- - if force || joins == 0 { - // There are no outstanding joins or we are forced to leave, clean up. - delete(n.mu.mcastJoins, id) - return n.removePermanentAddressLocked(addr) + if _, err := gep.LeaveGroup(addr); err != nil { + return err } - n.mu.mcastJoins[id] = joins return nil } // isInGroup returns true if n has joined the multicast group addr. func (n *NIC) isInGroup(addr tcpip.Address) bool { - n.mu.RLock() - joins := n.mu.mcastJoins[NetworkEndpointID{addr}] - n.mu.RUnlock() + for _, ep := range n.networkEndpoints { + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + continue + } - return joins != 0 + if gep.IsInGroup(addr) { + return true + } + } + + return false } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) +func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { + r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - - ref.ep.HandlePacket(&r, pkt) - ref.decRef() + addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + addressEndpoint.DecRef() } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -1236,7 +496,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // the ownership of the items is not retained by the caller. func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() - enabled := n.mu.enabled + enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { n.mu.RUnlock() @@ -1262,9 +522,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp local = n.linkEP.LinkAddress() } - // Are any packet sockets listening for this network protocol? + // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Add any other packet sockets that maybe listening for all protocols. + // Add any other packet type sockets that may be listening for all protocols. packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { @@ -1285,6 +545,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } if hasTransportHdr { + pkt.TransportProtocolNumber = transProtoNum // Parse the transport header if present. if state, ok := n.stack.transportProtocols[transProtoNum]; ok { state.proto.Parse(pkt) @@ -1293,29 +554,33 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return + if n.stack.handleLocal && !n.IsLoopback() { + if r := n.getAddress(protocol, src); r != nil { + r.DecRef() + + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } } - // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if !n.IsLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. + n.stack.stats.IP.IPTablesPreroutingDropped.Increment() return } } - if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) + if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { + n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) return } @@ -1323,7 +588,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1331,25 +596,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.ref.nic - n.mu.RLock() - ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() - n.mu.RUnlock() - if ok { - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote - r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, pkt) - ref.decRef() - r.Release() - return + n := r.nic + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { + if n.isValidForOutgoing(addressEndpoint) { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote + r.RemoteAddress = src + // TODO(b/123449044): Update the source NIC as well. + addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + addressEndpoint.DecRef() + r.Release() + return + } + + addressEndpoint.DecRef() } // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) @@ -1416,11 +682,11 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return + return TransportPacketProtocolUnreachable } transProto := state.proto @@ -1441,41 +707,47 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() - return + // We consider a malformed transport packet handled because there is + // nothing the caller can do. + return TransportPacketHandled } - } else { - // This is either a bad packet or was re-assembled from fragments. - transProto.Parse(pkt) + } else if !transProto.Parse(pkt) { + n.stack.stats.MalformedRcvdPackets.Increment() + return TransportPacketHandled } } - if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() { - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() - return + return TransportPacketHandled } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} if n.stack.demux.deliverPacket(r, protocol, pkt, id) { - return + return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { if state.defaultHandler(r, id, pkt) { - return + return TransportPacketHandled } } - // We could not find an appropriate destination for this packet, so - // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { + // We could not find an appropriate destination for this packet so + // give the protocol specific error handler a chance to handle it. + // If it doesn't handle it then we should do so. + switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() + return TransportPacketHandled + case UnknownDestinationPacketUnhandled: + return TransportPacketDestinationPortUnreachable + case UnknownDestinationPacketHandled: + return TransportPacketHandled + default: + panic(fmt.Sprintf("unrecognized result from HandleUnknownDestinationPacket = %d", res)) } } @@ -1508,96 +780,23 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } } -// ID returns the identifier of n. +// ID implements NetworkInterface. func (n *NIC) ID() tcpip.NICID { return n.id } -// Name returns the name of n. +// Name implements NetworkInterface. func (n *NIC) Name() string { return n.name } -// Stack returns the instance of the Stack that owns this NIC. -func (n *NIC) Stack() *Stack { - return n.stack -} - -// LinkEndpoint returns the link endpoint of n. +// LinkEndpoint implements NetworkInterface. func (n *NIC) LinkEndpoint() LinkEndpoint { return n.linkEP } -// isAddrTentative returns true if addr is tentative on n. -// -// Note that if addr is not associated with n, then this function will return -// false. It will only return true if the address is associated with the NIC -// AND it is tentative. -func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - n.mu.RLock() - defer n.mu.RUnlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return false - } - - return ref.getKind() == permanentTentative -} - -// dupTentativeAddrDetected attempts to inform n that a tentative addr is a -// duplicate on a link. -// -// dupTentativeAddrDetected will remove the tentative address if it exists. If -// the address was generated via SLAAC, an attempt will be made to generate a -// new address. -func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadAddress - } - - if ref.getKind() != permanentTentative { - return tcpip.ErrInvalidEndpointState - } - - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a - // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { - return err - } - - prefix := ref.addrWithPrefix().Subnet() - - switch ref.configType { - case slaac: - n.mu.ndp.regenerateSLAACAddr(prefix) - case slaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) - } - - return nil -} - -// setNDPConfigs sets the NDP configurations for n. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (n *NIC) setNDPConfigs(c NDPConfigurations) { - c.validate() - - n.mu.Lock() - n.mu.ndp.configs = c - n.mu.Unlock() -} - -// NUDConfigs gets the NUD configurations for n. -func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { +// nudConfigs gets the NUD configurations for n. +func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { return NUDConfigurations{}, tcpip.ErrNotSupported } @@ -1617,49 +816,6 @@ func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { return nil } -// handleNDPRA handles an NDP Router Advertisement message that arrived on n. -func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.handleRA(ip, ra) -} - -type networkEndpointKind int32 - -const ( - // A permanentTentative endpoint is a permanent address that is not yet - // considered to be fully bound to an interface in the traditional - // sense. That is, the address is associated with a NIC, but packets - // destined to the address MUST NOT be accepted and MUST be silently - // dropped, and the address MUST NOT be used as a source address for - // outgoing packets. For IPv6, addresses will be of this kind until - // NDP's Duplicate Address Detection has resolved, or be deleted if - // the process results in detecting a duplicate address. - permanentTentative networkEndpointKind = iota - - // A permanent endpoint is created by adding a permanent address (vs. a - // temporary one) to the NIC. Its reference count is biased by 1 to avoid - // removal when no route holds a reference to it. It is removed by explicitly - // removing the permanent address from the NIC. - permanent - - // An expired permanent endpoint is a permanent endpoint that had its address - // removed from the NIC, and it is waiting to be removed once no more routes - // hold a reference to it. This is achieved by decreasing its reference count - // by 1. If its address is re-added before the endpoint is removed, its type - // changes back to permanent and its reference count increases by 1 again. - permanentExpired - - // A temporary endpoint is created for spoofing outgoing packets, or when in - // promiscuous mode and accepting incoming packets that don't match any - // permanent endpoint. Its reference count is not biased by 1 and the - // endpoint is removed immediately when no more route holds a reference to - // it. A temporary endpoint can be promoted to permanent if its address - // is added permanently. - temporary -) - func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1690,153 +846,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } -type networkEndpointConfigType int32 - -const ( - // A statically configured endpoint is an address that was added by - // some user-specified action (adding an explicit address, joining a - // multicast group). - static networkEndpointConfigType = iota - - // A SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4862 section 5.5.3. - slaac - - // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are - // not expected to be valid (or preferred) forever; hence the term temporary. - slaacTemp -) - -type referencedNetworkEndpoint struct { - ep NetworkEndpoint - addr tcpip.AddressWithPrefix - nic *NIC - protocol tcpip.NetworkProtocolNumber - - // linkCache is set if link address resolution is enabled for this - // protocol. Set to nil otherwise. - linkCache LinkAddressCache - - // linkRes is set if link address resolution is enabled for this protocol. - // Set to nil otherwise. - linkRes LinkAddressResolver - - // refs is counting references held for this endpoint. When refs hits zero it - // triggers the automatic removal of the endpoint from the NIC. - refs int32 - - // networkEndpointKind must only be accessed using {get,set}Kind(). - kind networkEndpointKind - - // configType is the method that was used to configure this endpoint. - // This must never change except during endpoint creation and promotion to - // permanent. - configType networkEndpointConfigType - - // deprecated indicates whether or not the endpoint should be considered - // deprecated. That is, when deprecated is true, other endpoints that are not - // deprecated should be preferred. - deprecated bool -} - -func (r *referencedNetworkEndpoint) address() tcpip.Address { - return r.addr.Address -} - -func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { - return r.addr -} - -func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { - return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) -} - -func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { - atomic.StoreInt32((*int32)(&r.kind), int32(kind)) -} - // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address) // has been removed) unless the NIC is in spoofing mode, or temporary. -func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - r.nic.mu.RLock() - defer r.nic.mu.RUnlock() - - return r.isValidForOutgoingRLocked() -} - -// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires -// r.nic.mu to be read locked. -func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - if !r.nic.mu.enabled { - return false - } - - return r.isAssignedRLocked(r.nic.mu.spoofing) -} - -// isAssignedRLocked returns true if r is considered to be assigned to the NIC. -// -// r.nic.mu must be read locked. -func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { - switch r.getKind() { - case permanentTentative: - return false - case permanentExpired: - return spoofingOrPromiscuous - default: - return true - } -} - -// expireLocked decrements the reference count and marks the permanent endpoint -// as expired. -func (r *referencedNetworkEndpoint) expireLocked() { - r.setKind(permanentExpired) - r.decRefLocked() -} - -// decRef decrements the ref count and cleans up the endpoint once it reaches -// zero. -func (r *referencedNetworkEndpoint) decRef() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpoint(r) - } -} - -// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpointLocked(r) - } -} - -// incRef increments the ref count. It must only be called when the caller is -// known to be holding a reference to the endpoint, otherwise tryIncRef should -// be used. -func (r *referencedNetworkEndpoint) incRef() { - atomic.AddInt32(&r.refs, 1) -} - -// tryIncRef attempts to increment the ref count from n to n+1, but only if n is -// not zero. That is, it will increment the count if the endpoint is still -// alive, and do nothing if it has already been clean up. -func (r *referencedNetworkEndpoint) tryIncRef() bool { - for { - v := atomic.LoadInt32(&r.refs) - if v == 0 { - return false - } - - if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { - return true - } - } -} - -// stack returns the Stack instance that owns the underlying endpoint. -func (r *referencedNetworkEndpoint) stack() *Stack { - return r.nic.stack +func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + return n.Enabled() && ep.IsAssigned(spoofing) } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index dd6474297..fdd49b77f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,96 +15,40 @@ package stack import ( - "math" "testing" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) -var _ LinkEndpoint = (*testLinkEndpoint)(nil) +var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) +var _ NDPEndpoint = (*testIPv6Endpoint)(nil) -// A LinkEndpoint that throws away outgoing packets. +// An IPv6 NetworkEndpoint that throws away outgoing packets. // -// We use this instead of the channel endpoint as the channel package depends on +// We use this instead of ipv6.endpoint because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. -type testLinkEndpoint struct { - dispatcher NetworkDispatcher -} - -// Attach implements LinkEndpoint.Attach. -func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements LinkEndpoint.IsAttached. -func (e *testLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements LinkEndpoint.MTU. -func (*testLinkEndpoint) MTU() uint32 { - return math.MaxUint16 -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { - return CapabilityResolutionRequired -} +type testIPv6Endpoint struct { + AddressableEndpointState -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*testLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} + nicID tcpip.NICID + linkEP LinkEndpoint + protocol *testIPv6Protocol -// LinkAddress returns the link address of this endpoint. -func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" + invalidatedRtr tcpip.Address } -// Wait implements LinkEndpoint.Wait. -func (*testLinkEndpoint) Wait() {} - -// WritePacket implements LinkEndpoint.WritePacket. -func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) Enable() *tcpip.Error { return nil } -// WritePackets implements LinkEndpoint.WritePackets. -func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - panic("not implemented") +func (*testIPv6Endpoint) Enabled() bool { + return true } -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - nicID tcpip.NICID - linkEP LinkEndpoint - protocol *testIPv6Protocol -} +func (*testIPv6Endpoint) Disable() {} // DefaultTTL implements NetworkEndpoint.DefaultTTL. func (*testIPv6Endpoint) DefaultTTL() uint8 { @@ -116,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 { return e.linkEP.MTU() - header.IPv6MinimumSize } -// Capabilities implements NetworkEndpoint.Capabilities. -func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize @@ -144,23 +83,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip return tcpip.ErrNotSupported } -// NICID implements NetworkEndpoint.NICID. -func (e *testIPv6Endpoint) NICID() tcpip.NICID { - return e.nicID -} - // HandlePacket implements NetworkEndpoint.HandlePacket. func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { } // Close implements NetworkEndpoint.Close. -func (*testIPv6Endpoint) Close() {} +func (e *testIPv6Endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} // NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } +func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.invalidatedRtr = rtr +} + var _ NetworkProtocol = (*testIPv6Protocol)(nil) // An IPv6 NetworkProtocol that supports the bare minimum to make a stack @@ -192,12 +132,14 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { - return &testIPv6Endpoint{ - nicID: nicID, - linkEP: linkEP, +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { + e := &testIPv6Endpoint{ + nicID: nic.ID(), + linkEP: nic.LinkEndpoint(), protocol: p, } + e.AddressableEndpointState.Init(e) + return e } // SetOption implements NetworkProtocol.SetOption. @@ -241,38 +183,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd return "", false } -// Test the race condition where a NIC is removed and an RS timer fires at the -// same time. -func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { - const ( - nicID = 1 - - maxRtrSolicitations = 5 - ) - - e := testLinkEndpoint{} - s := New(Options{ - NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, - NDPConfigs: NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: minimumRtrSolicitationInterval, - }, - }) - - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) - } - - s.mu.Lock() - // Wait for the router solicitation timer to fire and block trying to obtain - // the stack lock when doing link address resolution. - time.Sleep(minimumRtrSolicitationInterval * 2) - if err := s.removeNICLocked(nicID); err != nil { - t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) - } - s.mu.Unlock() -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 2b97e5972..8cffb9fc6 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -60,7 +60,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, UseNeighborCache: true, }) @@ -137,7 +137,7 @@ func TestDefaultNUDConfigurations(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), UseNeighborCache: true, }) @@ -192,7 +192,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -249,7 +249,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -329,7 +329,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -391,7 +391,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -443,7 +443,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -495,7 +495,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -547,7 +547,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -599,7 +599,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 17b8beebb..105583c49 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) type headerType int @@ -80,11 +81,17 @@ type PacketBuffer struct { // data are held in the same underlying buffer storage. header buffer.Prependable - // NetworkProtocol is only valid when NetworkHeader is set. + // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() + // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol // numbers in registration APIs that take a PacketBuffer. NetworkProtocolNumber tcpip.NetworkProtocolNumber + // TransportProtocol is only valid if it is non zero. + // TODO(gvisor.dev/issue/3810): This and the network protocol number should + // be moved into the headerinfo. This should resolve the validity issue. + TransportProtocolNumber tcpip.TransportProtocolNumber + // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 @@ -234,20 +241,35 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, } return newPk } +// Network returns the network header as a header.Network. +// +// Network should only be called when NetworkHeader has been set. +func (pk *PacketBuffer) Network() header.Network { + switch netProto := pk.NetworkProtocolNumber; netProto { + case header.IPv4ProtocolNumber: + return header.IPv4(pk.NetworkHeader().View()) + case header.IPv6ProtocolNumber: + return header.IPv6(pk.NetworkHeader().View()) + default: + panic(fmt.Sprintf("unknown network protocol number %d", netProto)) + } +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2d88fa1f7..16f854e1f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -125,6 +127,26 @@ type PacketEndpoint interface { HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } +// UnknownDestinationPacketDisposition enumerates the possible return vaues from +// HandleUnknownDestinationPacket(). +type UnknownDestinationPacketDisposition int + +const ( + // UnknownDestinationPacketMalformed denotes that the packet was malformed + // and no further processing should be attempted other than updating + // statistics. + UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota + + // UnknownDestinationPacketUnhandled tells the caller that the packet was + // well formed but that the issue was not handled and the stack should take + // the default action. + UnknownDestinationPacketUnhandled + + // UnknownDestinationPacketHandled tells the caller that it should do + // no further processing. + UnknownDestinationPacketHandled +) + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { @@ -132,10 +154,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -147,24 +169,22 @@ type TransportProtocol interface { ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this - // protocol but that don't match any existing endpoint. For example, - // it is targeted at a port that have no listeners. - // - // The return value indicates whether the packet was well-formed (for - // stats purposes only). + // protocol that don't match any existing endpoint. For example, + // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // the issue. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -179,6 +199,25 @@ type TransportProtocol interface { Parse(pkt *PacketBuffer) (ok bool) } +// TransportPacketDisposition is the result from attempting to deliver a packet +// to the transport layer. +type TransportPacketDisposition int + +const ( + // TransportPacketHandled indicates that a transport packet was handled by the + // transport layer and callers need not take any further action. + TransportPacketHandled TransportPacketDisposition = iota + + // TransportPacketProtocolUnreachable indicates that the transport + // protocol requested in the packet is not supported. + TransportPacketProtocolUnreachable + + // TransportPacketDestinationPortUnreachable indicates that there weren't any + // listeners interested in the packet and the transport protocol has no means + // to notify the sender. + TransportPacketDestinationPortUnreachable +) + // TransportDispatcher contains the methods used by the network stack to deliver // packets to the appropriate transport endpoint after it has been handled by // the network layer. @@ -189,7 +228,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -226,9 +265,257 @@ type NetworkHeaderParams struct { TOS uint8 } +// GroupAddressableEndpoint is an endpoint that supports group addressing. +// +// An endpoint is considered to support group addressing when one or more +// endpoints may associate themselves with the same identifier (group address). +type GroupAddressableEndpoint interface { + // JoinGroup joins the spcified group. + // + // Returns true if the group was newly joined. + JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + + // LeaveGroup attempts to leave the specified group. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. + LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + + // IsInGroup returns true if the endpoint is a member of the specified group. + IsInGroup(group tcpip.Address) bool +} + +// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary +// behavior. +type PrimaryEndpointBehavior int + +const ( + // CanBePrimaryEndpoint indicates the endpoint can be used as a primary + // endpoint for new connections with no local address. This is the + // default when calling NIC.AddAddress. + CanBePrimaryEndpoint PrimaryEndpointBehavior = iota + + // FirstPrimaryEndpoint indicates the endpoint should be the first + // primary endpoint considered. If there are multiple endpoints with + // this behavior, they are ordered by recency. + FirstPrimaryEndpoint + + // NeverPrimaryEndpoint indicates the endpoint should never be a + // primary endpoint. + NeverPrimaryEndpoint +) + +// AddressConfigType is the method used to add an address. +type AddressConfigType int + +const ( + // AddressConfigStatic is a statically configured address endpoint that was + // added by some user-specified action (adding an explicit address, joining a + // multicast group). + AddressConfigStatic AddressConfigType = iota + + // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862 + // section 5.5.3. + AddressConfigSlaac + + // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as + // per RFC 4941. Temporary SLAAC addresses are short-lived and are not + // to be valid (or preferred) forever; hence the term temporary. + AddressConfigSlaacTemp +) + +// AssignableAddressEndpoint is a reference counted address endpoint that may be +// assigned to a NetworkEndpoint. +type AssignableAddressEndpoint interface { + // NetworkEndpoint returns the NetworkEndpoint the receiver is associated + // with. + NetworkEndpoint() NetworkEndpoint + + // AddressWithPrefix returns the endpoint's address. + AddressWithPrefix() tcpip.AddressWithPrefix + + // IsAssigned returns whether or not the endpoint is considered bound + // to its NetworkEndpoint. + IsAssigned(allowExpired bool) bool + + // IncRef increments this endpoint's reference count. + // + // Returns true if it was successfully incremented. If it returns false, then + // the endpoint is considered expired and should no longer be used. + IncRef() bool + + // DecRef decrements this endpoint's reference count. + DecRef() +} + +// AddressEndpoint is an endpoint representing an address assigned to an +// AddressableEndpoint. +type AddressEndpoint interface { + AssignableAddressEndpoint + + // GetKind returns the address kind for this endpoint. + GetKind() AddressKind + + // SetKind sets the address kind for this endpoint. + SetKind(AddressKind) + + // ConfigType returns the method used to add the address. + ConfigType() AddressConfigType + + // Deprecated returns whether or not this endpoint is deprecated. + Deprecated() bool + + // SetDeprecated sets this endpoint's deprecated status. + SetDeprecated(bool) +} + +// AddressKind is the kind of of an address. +// +// See the values of AddressKind for more details. +type AddressKind int + +const ( + // PermanentTentative is a permanent address endpoint that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses are of this kind until NDP's + // Duplicate Address Detection (DAD) resolves. If DAD fails, the address + // is removed. + PermanentTentative AddressKind = iota + + // Permanent is a permanent endpoint (vs. a temporary one) assigned to the + // NIC. Its reference count is biased by 1 to avoid removal when no route + // holds a reference to it. It is removed by explicitly removing the address + // from the NIC. + Permanent + + // PermanentExpired is a permanent endpoint that had its address removed from + // the NIC, and it is waiting to be removed once no references to it are held. + // + // If the address is re-added before the endpoint is removed, its type + // changes back to Permanent. + PermanentExpired + + // Temporary is an endpoint, created on a one-off basis to temporarily + // consider the NIC bound an an address that it is not explictiy bound to + // (such as a permanent address). Its reference count must not be biased by 1 + // so that the address is removed immediately when references to it are no + // longer held. + // + // A temporary endpoint may be promoted to permanent if the address is added + // permanently. + Temporary +) + +// IsPermanent returns true if the AddressKind represents a permanent address. +func (k AddressKind) IsPermanent() bool { + switch k { + case Permanent, PermanentTentative: + return true + case Temporary, PermanentExpired: + return false + default: + panic(fmt.Sprintf("unrecognized address kind = %d", k)) + } +} + +// AddressableEndpoint is an endpoint that supports addressing. +// +// An endpoint is considered to support addressing when the endpoint may +// associate itself with an identifier (address). +type AddressableEndpoint interface { + // AddAndAcquirePermanentAddress adds the passed permanent address. + // + // Returns tcpip.ErrDuplicateAddress if the address exists. + // + // Acquires and returns the AddressEndpoint for the added address. + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + + // RemovePermanentAddress removes the passed address if it is a permanent + // address. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // permanent address. + RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + + // MainAddress returns the endpoint's primary permanent address. + MainAddress() tcpip.AddressWithPrefix + + // AcquireAssignedAddress returns an address endpoint for the passed address + // that is considered bound to the endpoint, optionally creating a temporary + // endpoint if requested and no existing address exists. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if the specified address is not local to this endpoint. + AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint + + // AcquireOutgoingPrimaryAddress returns a primary address that may be used as + // a source address when sending packets to the passed remote address. + // + // If allowExpired is true, expired addresses may be returned. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if a primary address is not available. + AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint + + // PrimaryAddresses returns the primary addresses. + PrimaryAddresses() []tcpip.AddressWithPrefix + + // PermanentAddresses returns all the permanent addresses. + PermanentAddresses() []tcpip.AddressWithPrefix +} + +// NDPEndpoint is a network endpoint that supports NDP. +type NDPEndpoint interface { + NetworkEndpoint + + // InvalidateDefaultRouter invalidates a default router discovered through + // NDP. + InvalidateDefaultRouter(tcpip.Address) +} + +// NetworkInterface is a network interface. +type NetworkInterface interface { + // ID returns the interface's ID. + ID() tcpip.NICID + + // IsLoopback returns true if the interface is a loopback interface. + IsLoopback() bool + + // Name returns the name of the interface. + // + // May return an empty string if the interface is not configured with a name. + Name() string + + // Enabled returns true if the interface is enabled. + Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { + AddressableEndpoint + + // Enable enables the endpoint. + // + // Must only be called when the stack is in a state that allows the endpoint + // to send and receive packets. + // + // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() *tcpip.Error + + // Enabled returns true if the endpoint is enabled. + Enabled() bool + + // Disable disables the endpoint. + Disable() + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) // for this endpoint. DefaultTTL() uint8 @@ -238,10 +525,6 @@ type NetworkEndpoint interface { // minus the network endpoint max header length. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // underlying link-layer endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the network (and lower // level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -262,9 +545,6 @@ type NetworkEndpoint interface { // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error - // NICID returns the id of the NIC this endpoint belongs to. - NICID() tcpip.NICID - // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // @@ -279,6 +559,17 @@ type NetworkEndpoint interface { NetworkProtocolNumber() tcpip.NetworkProtocolNumber } +// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. +type ForwardingNetworkProtocol interface { + NetworkProtocol + + // Forwarding returns the forwarding configuration. + Forwarding() bool + + // SetForwarding sets the forwarding configuration. + SetForwarding(bool) +} + // NetworkProtocol is the interface that needs to be implemented by network // protocols (e.g., ipv4, ipv6) that want to be part of the networking stack. type NetworkProtocol interface { @@ -298,7 +589,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint + NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -427,8 +718,8 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // - // Attach will be called with a nil dispatcher if the receiver's associated - // NIC is being removed. + // Attach is called with a nil dispatcher when the endpoint's NIC is being + // removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index c2eabde9e..effe30155 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -42,21 +42,27 @@ type Route struct { // NetProto is the network-layer protocol. NetProto tcpip.NetworkProtocolNumber - // ref a reference to the network endpoint through which the route - // starts. - ref *referencedNetworkEndpoint - // Loop controls where WritePacket should send packets. Loop PacketLooping - // directedBroadcast indicates whether this route is sending a directed - // broadcast packet. - directedBroadcast bool + // nic is the NIC the route goes through. + nic *NIC + + // addressEndpoint is the local address this route is associated with. + addressEndpoint AssignableAddressEndpoint + + // linkCache is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkCache LinkAddressCache + + // linkRes is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkRes LinkAddressResolver } // makeRoute initializes a new route. It takes ownership of the provided -// reference to a network endpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route { +// AssignableAddressEndpoint. +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { loop := PacketOut if handleLocal && localAddr != "" && remoteAddr == localAddr { loop = PacketLoop @@ -66,29 +72,40 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - return Route{ + linkEP := nic.LinkEndpoint() + r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: localLinkAddr, + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, - ref: ref, + addressEndpoint: addressEndpoint, + nic: nic, Loop: loop, } + + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + r.linkRes = linkRes + r.linkCache = nic.stack + } + } + + return r } // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.ref.ep.NICID() + return r.nic.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.ref.ep.MaxHeaderLength() + return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.ref.nic.stack.Stats() + return r.nic.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -99,12 +116,12 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.ref.ep.Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.ref.ep.(GSOEndpoint); ok { + if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -142,8 +159,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if r.ref.nic.neigh != nil { - entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker) + if neigh := r.nic.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) if err != nil { return ch, err } @@ -151,7 +168,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { return ch, err } @@ -166,12 +183,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if r.ref.nic.neigh != nil { - r.ref.nic.neigh.removeWaker(nextAddr, waker) + if neigh := r.nic.neigh; neigh != nil { + neigh.removeWaker(nextAddr, waker) return } - r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) } // IsResolutionRequired returns true if Resolve() must be called to resolve @@ -179,104 +196,94 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // // The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.ref.nic.neigh != nil { - return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == "" + if r.nic.neigh != nil { + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" } - return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() - err := r.ref.ep.WritePacket(r, gso, params, pkt) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + if err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt); err != nil { + return err } - return err + + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return 0, tcpip.ErrInvalidEndpointState } - // WritePackets takes ownership of pkt, calculate length first. - numPkts := pkts.Len() - - n, err := r.ref.ep.WritePackets(r, gso, pkts, params) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) - } - r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - + n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params) + r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Data.Size() - if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil { return err } - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.ref.ep.DefaultTTL() + return r.addressEndpoint.NetworkEndpoint().DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.ref.ep.MTU() + return r.addressEndpoint.NetworkEndpoint().MTU() } // NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying // network endpoint. func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.ref.ep.NetworkProtocolNumber() + return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.ref != nil { - r.ref.decRef() - r.ref = nil + if r.addressEndpoint != nil { + r.addressEndpoint.DecRef() + r.addressEndpoint = nil } } -// Clone Clone a route such that the original one can be released and the new -// one will remain valid. +// Clone clones the route. func (r *Route) Clone() Route { - if r.ref != nil { - r.ref.incRef() + if r.addressEndpoint != nil { + _ = r.addressEndpoint.IncRef() } return *r } @@ -300,27 +307,30 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.ref.stack() + return r.nic.stack +} + +func (r *Route) isV4Broadcast(addr tcpip.Address) bool { + if addr == header.IPv4Broadcast { + return true + } + + subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + return subnet.IsBroadcast(addr) } // IsOutboundBroadcast returns true if the route is for an outbound broadcast // packet. func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast + return r.isV4Broadcast(r.RemoteAddress) } // IsInboundBroadcast returns true if the route is for an inbound broadcast // packet. func (r *Route) IsInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - if r.LocalAddress == header.IPv4Broadcast { - return true - } - - addr := r.ref.addrWithPrefix() - subnet := addr.Subnet() - return subnet.IsBroadcast(r.LocalAddress) + return r.isV4Broadcast(r.LocalAddress) } // ReverseRoute returns new route with given source and destination address. @@ -331,7 +341,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { LocalLinkAddress: r.RemoteLinkAddress, RemoteAddress: src, RemoteLinkAddress: r.LocalLinkAddress, - ref: r.ref, Loop: r.Loop, + addressEndpoint: r.addressEndpoint, + nic: r.nic, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c86ee1c13..57d8e79e0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -144,10 +144,7 @@ type TCPReceiverState struct { // PendingBufUsed is the number of bytes pending in the receive // queue. - PendingBufUsed seqnum.Size - - // PendingBufSize is the size of the socket receive buffer. - PendingBufSize seqnum.Size + PendingBufUsed int } // TCPSenderState holds a copy of the internal state of the sender for @@ -366,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } -// NICNameFromID is a function that returns a stable name for the specified NIC, -// even if different NIC IDs are used to refer to the same NIC in different -// program runs. It is used when generating opaque interface identifiers (IIDs). -// If the NIC was created with a name, it will be passed to NICNameFromID. -// -// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. -type NICNameFromID func(tcpip.NICID, string) string - -// OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. -type OpaqueInterfaceIdentifierOptions struct { - // NICNameFromID is a function that returns a stable name for a specified NIC, - // even if the NIC ID changes over time. - // - // Must be specified to generate the opaque IID. - NICNameFromID NICNameFromID - - // SecretKey is a pseudo-random number used as the secret key when generating - // opaque IIDs as defined by RFC 7217. The key SHOULD be at least - // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness - // requirements for security as outlined by RFC 4086. SecretKey MUST NOT - // change between program runs, unless explicitly changed. - // - // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey - // MUST NOT be modified after Stack is created. - // - // May be nil, but a nil value is highly discouraged to maintain - // some level of randomness between nodes. - SecretKey []byte -} - // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -415,10 +380,12 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool - cleanupEndpoints map[TransportEndpoint]struct{} + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + + // cleanupEndpointsMu protects cleanupEndpoints. + cleanupEndpointsMu sync.Mutex + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -429,7 +396,7 @@ type Stack struct { // If not nil, then any new endpoints will have this probe function // invoked everytime they receive a TCP segment. - tcpProbeFunc TCPProbeFunc + tcpProbeFunc atomic.Value // TCPProbeFunc // clock is used to generate user-visible times. clock tcpip.Clock @@ -455,9 +422,6 @@ type Stack struct { // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 - // ndpConfigs is the default NDP configurations used by interfaces. - ndpConfigs NDPConfigurations - // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations @@ -465,15 +429,6 @@ type Stack struct { // by the NIC's neighborCache instead of linkAddrCache. useNeighborCache bool - // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher - // nudDisp is the NUD event dispatcher that is used to send the netstack // integrator NUD related events. nudDisp NUDDispatcher @@ -481,14 +436,6 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -511,13 +458,25 @@ type UniqueID interface { UniqueID() uint64 } +// NetworkProtocolFactory instantiates a network protocol. +// +// NetworkProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type NetworkProtocolFactory func(*Stack) NetworkProtocol + +// TransportProtocolFactory instantiates a transport protocol. +// +// TransportProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type TransportProtocolFactory func(*Stack) TransportProtocol + // Options contains optional Stack configuration. type Options struct { // NetworkProtocols lists the network protocols to enable. - NetworkProtocols []NetworkProtocol + NetworkProtocols []NetworkProtocolFactory // TransportProtocols lists the transport protocols to enable. - TransportProtocols []TransportProtocol + TransportProtocols []TransportProtocolFactory // Clock is an optional clock source used for timestampping packets. // @@ -535,13 +494,6 @@ type Options struct { // UniqueID is an optional generator of unique identifiers. UniqueID UniqueID - // NDPConfigs is the default NDP configurations used by interfaces. - // - // By default, NDPConfigs will have a zero value for its - // DupAddrDetectTransmits field, implying that DAD will not be performed - // before assigning an address to a NIC. - NDPConfigs NDPConfigurations - // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations @@ -552,24 +504,6 @@ type Options struct { // and ClearNeighbors. UseNeighborCache bool - // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. - // - // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address Detection - // is enabled, an address will only be assigned if it successfully resolves. - // If it fails, no further attempt will be made to auto-generate an IPv6 - // link-local address. - // - // The generated link-local address will follow RFC 4291 Appendix A - // guidelines. - AutoGenIPv6LinkLocal bool - - // NDPDisp is the NDP event dispatcher that an integrator can provide to - // receive NDP related events. - NDPDisp NDPDispatcher - // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -578,31 +512,12 @@ type Options struct { // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface - // identifiers (IIDs) as outlined by RFC 7217. - OpaqueIIDOpts OpaqueInterfaceIdentifierOptions - // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by rand.Read(). // // RandSource must be thread-safe. RandSource mathrand.Source - - // TempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable - // and random from the perspective of other nodes on the network. It is - // recommended that the seed be a random byte buffer of at least - // header.IIDSize bytes to make sure that temporary SLAAC addresses are - // sufficiently random. It should follow minimum randomness requirements for - // security as outlined by RFC 4086. - // - // Note: using a nil value, the same seed across netstack program runs, or a - // seed that is too small would reduce randomness and increase predictability, - // defeating the purpose of temporary SLAAC addresses. - TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -705,36 +620,28 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } - // Make sure opts.NDPConfigs contains valid values only. - opts.NDPConfigs.validate() - opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: DefaultTables(), - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - ndpConfigs: opts.NDPConfigs, - nudConfigs: opts.NUDConfigs, - useNeighborCache: opts.UseNeighborCache, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, - uniqueIDGenerator: opts.UniqueID, - ndpDisp: opts.NDPDisp, - nudDisp: opts.NUDDisp, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - forwarder: newForwardQueue(), - randomGenerator: mathrand.New(randSrc), + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: DefaultTables(), + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + useNeighborCache: opts.UseNeighborCache, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -748,7 +655,8 @@ func New(opts Options) *Stack { } // Add specified network protocols. - for _, netProto := range opts.NetworkProtocols { + for _, netProtoFactory := range opts.NetworkProtocols { + netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -756,7 +664,8 @@ func New(opts Options) *Stack { } // Add specified transport protocols. - for _, transProto := range opts.TransportProtocols { + for _, transProtoFactory := range opts.TransportProtocols { + transProto := transProtoFactory(s) s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } @@ -814,7 +723,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -829,7 +738,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -863,46 +772,37 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs for the +// passed protocol. +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return tcpip.ErrUnknownProtocol + } - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { - return + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return tcpip.ErrNotSupported } - s.forwarding = enable + forwardingProtocol.SetForwarding(enable) + return nil +} - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { - return +// Forwarding returns true if packet forwarding between NICs is enabled for the +// passed protocol. +func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return false } - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() - } + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return false } -} -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding + return forwardingProtocol.Forwarding() } // SetRouteTable assigns the route table to be used by this stack. It @@ -937,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewEndpoint(s, network, waiterQueue) + return t.proto.NewEndpoint(network, waiterQueue) } // NewRawEndpoint creates a new raw transport layer endpoint of the given @@ -957,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewRawEndpoint(s, network, waiterQueue) + return t.proto.NewRawEndpoint(network, waiterQueue) } // NewPacketEndpoint creates a new packet endpoint listening for the given @@ -1064,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - return nic.disable() + nic.disable() + return nil } // CheckNIC checks if a NIC is usable. @@ -1077,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } - return nic.enabled() + return nic.Enabled() } // RemoveNIC removes NIC and all related routes from the network stack. @@ -1155,14 +1056,14 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.enabled(), + Running: nic.Enabled(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.isLoopback(), + Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.PrimaryAddresses(), + ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -1226,7 +1127,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, peb) } // RemoveAddress removes an existing network-layer address from the specified @@ -1236,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + return nic.removeAddress(addr) } return tcpip.ErrUnknownNICID @@ -1250,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) for id, nic := range s.nics { - nics[id] = nic.AllAddresses() + nics[id] = nic.allPermanentAddresses() } return nics } @@ -1272,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { if len(localAddr) == 0 { return nic.primaryEndpoint(netProto, remoteAddr) } @@ -1289,9 +1190,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { - return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil + if nic, ok := s.nics[id]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil } } } else { @@ -1299,22 +1200,20 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { + if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. - remoteAddr = ref.address() + remoteAddr = addressEndpoint.AddressWithPrefix().Address } - r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) - + r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) if len(route.Gateway) > 0 { if needRoute { r.NextHop = route.Gateway } - } else if r.directedBroadcast { + } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } @@ -1352,21 +1251,20 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto return 0 } - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref == nil { + addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) + if addressEndpoint == nil { return 0 } - ref.decRef() + addressEndpoint.DecRef() return nic.id } // Go through all the NICs. for _, nic := range s.nics { - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref != nil { - ref.decRef() + if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() return nic.id } } @@ -1528,10 +1426,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { - s.mu.Lock() - defer s.mu.Unlock() - + s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} + s.cleanupEndpointsMu.Unlock() s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } @@ -1539,9 +1436,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp // CompleteTransportEndpointCleanup removes the endpoint from the cleanup // stage. func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() delete(s.cleanupEndpoints, ep) - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // FindTransportEndpoint finds an endpoint that most closely matches the provided @@ -1584,23 +1481,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint { // CleanupEndpoints returns endpoints currently in the cleanup state. func (s *Stack) CleanupEndpoints() []TransportEndpoint { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) for e := range s.cleanupEndpoints { es = append(es, e) } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() return es } // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() for _, e := range es { s.cleanupEndpoints[e] = struct{}{} } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // Close closes all currently registered transport endpoints. @@ -1795,18 +1692,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra // guarantee provided on which probe will be invoked. Ideally this should only // be called once per stack. func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { - s.mu.Lock() - s.tcpProbeFunc = probe - s.mu.Unlock() + s.tcpProbeFunc.Store(probe) } // GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil // otherwise. func (s *Stack) GetTCPProbe() TCPProbeFunc { - s.mu.Lock() - p := s.tcpProbeFunc - s.mu.Unlock() - return p + p := s.tcpProbeFunc.Load() + if p == nil { + return nil + } + return p.(TCPProbeFunc) } // RemoveTCPProbe removes an installed TCP probe. @@ -1815,9 +1711,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc { // have a probe attached. Endpoints already created will continue to invoke // TCP probe. func (s *Stack) RemoveTCPProbe() { - s.mu.Lock() - s.tcpProbeFunc = nil - s.mu.Unlock() + // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics. + s.tcpProbeFunc.Store(TCPProbeFunc(nil)) } // JoinGroup joins the given multicast group on the given NIC. @@ -1838,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { - return nic.leaveGroup(multicastAddr) + return nic.leaveGroup(protocol, multicastAddr) } return tcpip.ErrUnknownNICID } @@ -1890,53 +1785,18 @@ func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } -// IsAddrTentative returns true if addr is tentative on the NIC with ID id. -// -// Note that if addr is not associated with a NIC with id ID, then this -// function will return false. It will only return true if the address is -// associated with the NIC AND it is tentative. -func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[id] - if !ok { - return false, tcpip.ErrUnknownNICID - } - - return nic.isAddrTentative(addr), nil -} - -// DupTentativeAddrDetected attempts to inform the NIC with ID id that a -// tentative addr on it is a duplicate on a link. -func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol +// number installed on the specified NIC. +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() - nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - return nic.dupTentativeAddrDetected(addr) -} - -// SetNDPConfigurations sets the per-interface NDP configurations on the NIC -// with ID id to c. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - - nic, ok := s.nics[id] + nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return nil, tcpip.ErrUnknownNICID } - nic.setNDPConfigs(c) - return nil + return nic.networkEndpoints[proto], nil } // NUDConfigurations gets the per-interface NUD configurations. @@ -1949,7 +1809,7 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err return NUDConfigurations{}, tcpip.ErrUnknownNICID } - return nic.NUDConfigs() + return nic.nudConfigs() } // SetNUDConfigurations sets the per-interface NUD configurations. @@ -1968,22 +1828,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip return nic.setNUDConfigs(c) } -// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement -// message that it needs to handle. -func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - - nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - nic.handleNDPRA(ip, ra) - - return nil -} - // Seed returns a 32 bit value that can be used as a seed value for port // picking, ISN generation etc. // @@ -2025,16 +1869,14 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres defer s.mu.RUnlock() for _, nic := range s.nics { - id := NetworkEndpointID{address} - - if ref, ok := nic.mu.endpoints[id]; ok { - nic.mu.RLock() - defer nic.mu.RUnlock() - - // An endpoint with this id exists, check if it can be - // used and return it. - return ref.ep, nil + addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + continue } + + ep := addressEndpoint.NetworkEndpoint() + addressEndpoint.DecRef() + return ep, nil } return nil, tcpip.ErrBadAddress } @@ -2051,3 +1893,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { return nic.Name() } + +// NewJob returns a new tcpip.Job using the stack's clock. +func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 60b54c244..aa20f750b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -68,18 +69,41 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { + stack.AddressableEndpointState + + mu struct { + sync.RWMutex + + enabled bool + } + nicID tcpip.NICID proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint } -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = true + return nil } -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (f *fakeNetworkEndpoint) Enabled() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.enabled +} + +func (f *fakeNetworkEndpoint) Disable() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = false +} + +func (f *fakeNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { @@ -118,10 +142,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto return 0 } -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -156,7 +176,9 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack return tcpip.ErrNotSupported } -func (*fakeNetworkEndpoint) Close() {} +func (f *fakeNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index @@ -165,6 +187,11 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 + + mu struct { + sync.RWMutex + forwarding bool + } } func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -187,13 +214,15 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { - return &fakeNetworkEndpoint{ - nicID: nicID, +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &fakeNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, + ep: nic.LinkEndpoint(), } + e.AddressableEndpointState.Init(e) + return e } func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { @@ -216,13 +245,13 @@ func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) } } -// Close implements TransportProtocol.Close. +// Close implements NetworkProtocol.Close. func (*fakeNetworkProtocol) Close() {} -// Wait implements TransportProtocol.Wait. +// Wait implements NetworkProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} -// Parse implements TransportProtocol.Parse. +// Parse implements NetworkProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { @@ -231,7 +260,21 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -func fakeNetFactory() stack.NetworkProtocol { +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + +func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -268,7 +311,7 @@ func TestNetworkReceive(t *testing.T) { // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -428,7 +471,7 @@ func TestNetworkSend(t *testing.T) { // existing nic. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) @@ -455,7 +498,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -555,7 +598,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -574,7 +617,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { @@ -586,7 +629,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := loopback.New() @@ -633,7 +676,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { @@ -645,7 +688,7 @@ func TestRemoveNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -706,7 +749,7 @@ func TestRouteWithDownNIC(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(1, defaultMTU, "") @@ -872,7 +915,7 @@ func TestRoutes(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -952,7 +995,7 @@ func TestAddressRemoval(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -999,7 +1042,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1090,7 +1133,7 @@ func TestEndpointExpiration(t *testing.T) { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1248,7 +1291,7 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1300,7 +1343,7 @@ func TestSpoofingWithAddress(t *testing.T) { dstAddr := tcpip.Address("\x03") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1366,7 +1409,7 @@ func TestSpoofingNoAddress(t *testing.T) { dstAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1429,7 +1472,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error { func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1472,7 +1515,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // Create a new stack with two NICs. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1573,7 +1616,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1630,8 +1673,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{}, }) opt := tcpip.DefaultTTLOption(5) @@ -1657,7 +1700,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1724,7 +1767,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1809,7 +1852,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1836,7 +1879,7 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1870,7 +1913,7 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1901,7 +1944,7 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2022,7 +2065,7 @@ func TestCreateNICWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2089,9 +2132,9 @@ func TestNICForwarding(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { @@ -2213,7 +2256,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName string autoGen bool linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions + iidOpts ipv6.OpaqueInterfaceIdentifierOptions shouldGen bool expectedAddr tcpip.Address }{ @@ -2229,7 +2272,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: false, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2274,7 +2317,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: true, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2286,7 +2329,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { { name: "OIID Empty MAC and empty nicName", autoGen: true, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:1], }, @@ -2298,7 +2341,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test", autoGen: true, linkAddr: "\x01\x02\x03", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:2], }, @@ -2310,7 +2353,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test2", autoGen: true, linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:3], }, @@ -2322,7 +2365,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test3", autoGen: true, linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, }, shouldGen: true, @@ -2336,10 +2379,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, + })}, } e := channel.New(0, 1280, test.linkAddr) @@ -2411,15 +2455,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { tests := []struct { name string - opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions }{ { name: "IID From MAC", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, }, { name: "Opaque IID", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName }, @@ -2430,9 +2474,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + })}, } e := loopback.New() @@ -2461,12 +2506,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - ndpConfigs := stack.DefaultNDPConfigurations() + ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + })}, } e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2522,7 +2568,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { for _, ps := range pebs { t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2813,14 +2859,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2869,7 +2916,7 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { @@ -2921,7 +2968,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -2982,7 +3029,7 @@ func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { @@ -3059,12 +3106,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + })}, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3423,7 +3471,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, ep); err != nil { @@ -3461,7 +3509,7 @@ func TestResolveWith(t *testing.T) { ) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -3499,3 +3547,131 @@ func TestResolveWith(t *testing.T) { t.Fatal("got r.IsResolutionRequired() = true, want = false") } } + +// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its +// associated address is removed should not cause a panic. +func TestRouteReleaseAfterAddrRemoval(t *testing.T) { + const ( + nicID = 1 + localAddr = tcpip.Address("\x01") + remoteAddr = tcpip.Address("\x02") + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) + } + // Should not panic. + defer r.Release() + + // Check that removing the same address fails. + if err := s.RemoveAddress(nicID, localAddr); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) + } +} + +func TestGetNetworkEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) + for _, test := range tests { + factories = append(factories, test.protoFactory) + } + + s := stack.New(stack.Options{ + NetworkProtocols: factories, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) + } + + if got := ep.NetworkProtocolNumber(); got != test.protoNum { + t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) + } + }) + } +} + +func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + } + + // Check that we get the right initial address and prefix length. + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } + + // Should still get the address when the NIC is diabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..35e5b1a2e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return true } - // If the packet is a TCP packet with a non-unicast source or destination - // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // If the packet is a TCP packet with a unspecified source or non-unicast + // destination address, then do nothing further and instruct the caller to do + // the same. The network layer handles address validation for specified source + // addresses. + if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -677,10 +679,10 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isSpecified(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4d6d62eec..698c8609e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -51,8 +51,8 @@ type testContext struct { // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { @@ -182,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a1458c899..62ab6d92f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -39,7 +39,7 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo - stack *stack.Stack + proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -59,8 +59,8 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Abort() { @@ -143,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) + r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return tcpip.ErrNoRoute } @@ -151,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -180,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -190,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip. } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( + if err := f.proto.stack.RegisterTransportEndpoint( a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, @@ -218,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -262,6 +261,8 @@ type fakeTransportProtocolOptions struct { // fakeTransportProtocol is a transport-layer protocol descriptor. It // aggregates the number of packets received via endpoints of this protocol. type fakeTransportProtocol struct { + stack *stack.Stack + packetCount int controlCount int opts fakeTransportProtocolOptions @@ -271,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil } -func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -287,26 +288,24 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) + case *tcpip.TCPModerateReceiveBufferOption: + f.opts.good = bool(*v) return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) + case *tcpip.TCPModerateReceiveBufferOption: + *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: return tcpip.ErrUnknownProtocolOption @@ -328,15 +327,15 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { return ok } -func fakeTransFactory() stack.TransportProtocol { - return &fakeTransportProtocol{} +func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { + return &fakeTransportProtocol{stack: s} } func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -406,8 +405,8 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -483,8 +482,8 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -529,54 +528,29 @@ func TestTransportSend(t *testing.T) { func TestTransportOptions(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } + v := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) + } + v = false + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) + } + if !v { + t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } func TestTransportForwarding(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() @@ -631,7 +605,7 @@ func TestTransportForwarding(t *testing.T) { Data: req.ToVectorisedView(), })) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err != nil || aep == nil { t.Fatalf("Accept failed: %v, %v", aep, err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 47a8d7c86..c42bb0991 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -237,6 +237,14 @@ type Timer interface { // network node. Or, in the case of unix endpoints, it may represent a path. type Address string +// WithPrefix returns the address with a prefix that represents a point subnet. +func (a Address) WithPrefix() AddressWithPrefix { + return AddressWithPrefix{ + Address: a, + PrefixLen: len(a) * 8, + } +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -561,7 +569,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *waiter.Queue, *Error) + // + // If peerAddr is not nil then it is populated with the peer address of the + // returned endpoint. + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -861,12 +872,93 @@ func (*DefaultTTLOption) isGettableNetworkProtocolOption() {} func (*DefaultTTLOption) isSettableNetworkProtocolOption() {} -// AvailableCongestionControlOption is used to query the supported congestion -// control algorithms. -type AvailableCongestionControlOption string +// GettableTransportProtocolOption is a marker interface for transport protocol +// options that may be queried. +type GettableTransportProtocolOption interface { + isGettableTransportProtocolOption() +} + +// SettableTransportProtocolOption is a marker interface for transport protocol +// options that may be set. +type SettableTransportProtocolOption interface { + isSettableTransportProtocolOption() +} + +// TCPSACKEnabled the SACK option for TCP. +// +// See: https://tools.ietf.org/html/rfc2018. +type TCPSACKEnabled bool + +func (*TCPSACKEnabled) isGettableTransportProtocolOption() {} + +func (*TCPSACKEnabled) isSettableTransportProtocolOption() {} + +// TCPRecovery is the loss deteoction algorithm used by TCP. +type TCPRecovery int32 + +func (*TCPRecovery) isGettableTransportProtocolOption() {} + +func (*TCPRecovery) isSettableTransportProtocolOption() {} + +const ( + // TCPRACKLossDetection indicates RACK is used for loss detection and + // recovery. + TCPRACKLossDetection TCPRecovery = 1 << iota + + // TCPRACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + TCPRACKStaticReoWnd + + // TCPRACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + TCPRACKNoDupTh +) + +// TCPDelayEnabled enables/disables Nagle's algorithm in TCP. +type TCPDelayEnabled bool + +func (*TCPDelayEnabled) isGettableTransportProtocolOption() {} + +func (*TCPDelayEnabled) isSettableTransportProtocolOption() {} + +// TCPSendBufferSizeRangeOption is the send buffer size range for TCP. +type TCPSendBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPSendBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPSendBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPReceiveBufferSizeRangeOption is the receive buffer size range for TCP. +type TCPReceiveBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPReceiveBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPReceiveBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPAvailableCongestionControlOption is the supported congestion control +// algorithms for TCP +type TCPAvailableCongestionControlOption string + +func (*TCPAvailableCongestionControlOption) isGettableTransportProtocolOption() {} + +func (*TCPAvailableCongestionControlOption) isSettableTransportProtocolOption() {} + +// TCPModerateReceiveBufferOption enables/disables receive buffer moderation +// for TCP. +type TCPModerateReceiveBufferOption bool + +func (*TCPModerateReceiveBufferOption) isGettableTransportProtocolOption() {} -// ModerateReceiveBufferOption is used by buffer moderation. -type ModerateReceiveBufferOption bool +func (*TCPModerateReceiveBufferOption) isSettableTransportProtocolOption() {} // GettableSocketOption is a marker interface for socket options that may be // queried. @@ -932,6 +1024,10 @@ func (*CongestionControlOption) isGettableSocketOption() {} func (*CongestionControlOption) isSettableSocketOption() {} +func (*CongestionControlOption) isGettableTransportProtocolOption() {} + +func (*CongestionControlOption) isSettableTransportProtocolOption() {} + // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. @@ -941,6 +1037,10 @@ func (*TCPLingerTimeoutOption) isGettableSocketOption() {} func (*TCPLingerTimeoutOption) isSettableSocketOption() {} +func (*TCPLingerTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPLingerTimeoutOption) isSettableTransportProtocolOption() {} + // TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TIME_WAIT state // before being marked closed. @@ -950,6 +1050,10 @@ func (*TCPTimeWaitTimeoutOption) isGettableSocketOption() {} func (*TCPTimeWaitTimeoutOption) isSettableSocketOption() {} +func (*TCPTimeWaitTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitTimeoutOption) isSettableTransportProtocolOption() {} + // TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a // accept to return a completed connection only when there is data to be // read. This usually means the listening socket will drop the final ACK @@ -968,6 +1072,10 @@ func (*TCPMinRTOOption) isGettableSocketOption() {} func (*TCPMinRTOOption) isSettableSocketOption() {} +func (*TCPMinRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMinRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding // default MaxRTO used by the Stack. type TCPMaxRTOOption time.Duration @@ -976,6 +1084,10 @@ func (*TCPMaxRTOOption) isGettableSocketOption() {} func (*TCPMaxRTOOption) isSettableSocketOption() {} +func (*TCPMaxRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the // maximum number of retransmits after which we time out the connection. type TCPMaxRetriesOption uint64 @@ -984,6 +1096,10 @@ func (*TCPMaxRetriesOption) isGettableSocketOption() {} func (*TCPMaxRetriesOption) isSettableSocketOption() {} +func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} + // TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify // the number of endpoints that can be in SYN-RCVD state before the stack // switches to using SYN cookies. @@ -993,6 +1109,10 @@ func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} +func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} + // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 @@ -1001,6 +1121,10 @@ func (*TCPSynRetriesOption) isGettableSocketOption() {} func (*TCPSynRetriesOption) isSettableSocketOption() {} +func (*TCPSynRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRetriesOption) isSettableTransportProtocolOption() {} + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -1059,6 +1183,10 @@ func (*TCPTimeWaitReuseOption) isGettableSocketOption() {} func (*TCPTimeWaitReuseOption) isSettableSocketOption() {} +func (*TCPTimeWaitReuseOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitReuseOption) isSettableTransportProtocolOption() {} + const ( // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot // be reused for new connections. @@ -1354,6 +1482,18 @@ type IPStats struct { // MalformedFragmentsReceived is the total number of IP Fragments that were // dropped due to the fragment failing validation checks. MalformedFragmentsReceived *StatCounter + + // IPTablesPreroutingDropped is the total number of IP packets dropped + // in the Prerouting chain. + IPTablesPreroutingDropped *StatCounter + + // IPTablesInputDropped is the total number of IP packets dropped in + // the Input chain. + IPTablesInputDropped *StatCounter + + // IPTablesOutputDropped is the total number of IP packets dropped in + // the Output chain. + IPTablesOutputDropped *StatCounter } // TCPStats collects TCP-specific stats. @@ -1482,9 +1622,6 @@ type UDPStats struct { // ChecksumErrors is the number of datagrams dropped due to bad checksums. ChecksumErrors *StatCounter - - // InvalidSourceAddress is the number of invalid sourced datagrams dropped. - InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 1b18023c5..e8caf09ba 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -16,6 +16,7 @@ package integration_test import ( "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -29,6 +30,69 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) + +type ndpDispatcher struct{} + +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { + return false +} + +func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} + +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} + +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return true +} + +func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} + +func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {} + +func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {} + +func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {} + +func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {} + +// TestInitialLoopbackAddresses tests that the loopback interface does not +// auto-generate a link-local address when it is brought up. +func TestInitialLoopbackAddresses(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDispatcher{}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID, nicName string) string { + t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) + return "" + }, + }, + })}, + }) + + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + nicsInfo := s.NICInfo() + if nicInfo, ok := nicsInfo[nicID]; !ok { + t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo) + } else if got := len(nicInfo.ProtocolAddresses); got != 0 { + t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses) + } +} + // TestLoopbackAcceptAllInSubnet tests that a loopback interface considers // itself bound to all addresses in the subnet of an assigned address. func TestLoopbackAcceptAllInSubnet(t *testing.T) { @@ -120,8 +184,8 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -187,3 +251,64 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }) } } + +// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address +// in a loopback interface's associated subnet is bound to the permanently bound +// address. +func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { + const nicID = 1 + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + addrBytes := []byte(ipv4Addr.Address) + addrBytes[len(addrBytes)-1]++ + otherAddr := tcpip.Address(addrBytes) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err) + } + defer r.Release() + + params := stack.NetworkHeaderParams{ + Protocol: 111, + TTL: 64, + TOS: stack.DefaultTOS, + } + data := buffer.View([]byte{1, 2, 3, 4}) + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) + } + + // Removing the address should make the endpoint invalid. + if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) + } + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 52c27e045..72d86b5ab 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -139,11 +140,9 @@ func TestPingMulticastBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ipv4Proto := ipv4.NewProtocol() - ipv6Proto := ipv6.NewProtocol() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, }) // We only expect a single packet in response to our ICMP Echo Request. e := channel.New(1, defaultMTU, "") @@ -175,18 +174,18 @@ func TestPingMulticastBroadcast(t *testing.T) { var rxICMP func(*channel.Endpoint, tcpip.Address) var expectedSrc tcpip.Address var expectedDst tcpip.Address - var proto stack.NetworkProtocol + var protoNum tcpip.NetworkProtocolNumber switch l := len(test.dstAddr); l { case header.IPv4AddressSize: rxICMP = rxIPv4ICMP expectedSrc = ipv4Addr.Address expectedDst = remoteIPv4Addr - proto = ipv4Proto + protoNum = header.IPv4ProtocolNumber case header.IPv6AddressSize: rxICMP = rxIPv6ICMP expectedSrc = ipv6Addr.Address expectedDst = remoteIPv6Addr - proto = ipv6Proto + protoNum = header.IPv6ProtocolNumber default: t.Fatalf("got unexpected address length = %d bytes", l) } @@ -204,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -379,8 +378,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -436,3 +435,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { }) } } + +// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all +// interested endpoints. +func TestReuseAddrAndBroadcast(t *testing.T) { + const ( + nicID = 1 + localPort = 9000 + loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + ) + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + + tests := []struct { + name string + broadcastAddr tcpip.Address + }{ + { + name: "Subnet directed broadcast", + broadcastAddr: loopbackBroadcast, + }, + { + name: "IPv4 broadcast", + broadcastAddr: header.IPv4Broadcast, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x7f\x00\x00\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + // We use the empty subnet instead of just the loopback subnet so we + // also have a route to the IPv4 Broadcast address. + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // We create endpoints that bind to both the wildcard address and the + // broadcast address to make sure both of these types of "broadcast + // interested" endpoints receive broadcast packets. + wq := waiter.Queue{} + var eps []tcpip.Endpoint + for _, bindWildcard := range []bool{false, true} { + // Create multiple endpoints for each type of "broadcast interested" + // endpoint so we can test that all endpoints receive the broadcast + // packet. + for i := 0; i < 2; i++ { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + defer ep.Close() + + if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) + } + + bindAddr := tcpip.FullAddress{Port: localPort} + if bindWildcard { + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } else { + bindAddr.Addr = test.broadcastAddr + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } + + eps = append(eps, ep) + } + } + + for i, wep := range eps { + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.broadcastAddr, + Port: localPort, + }, + } + if n, _, err := wep.Write(data, writeOpts); err != nil { + t.Fatalf("eps[%d].Write(_, _): %s", i, err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + } + + for j, rep := range eps { + if gotPayload, _, err := rep.Read(nil); err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) + } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) + } + } + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 346ca4bda..41eb0ca44 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -74,6 +74,8 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -344,9 +346,14 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -415,8 +422,17 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -430,6 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -462,6 +479,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -597,7 +615,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 74ef6541e..87d510f96 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -13,12 +13,7 @@ // limitations under the License. // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport -// protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and activated on the stack by passing -// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport -// protocols when calling stack.New(). Then endpoints can be created by passing -// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number -// when calling Stack.NewEndpoint(). +// protocols for use in ping. package icmp import ( @@ -42,6 +37,8 @@ const ( // protocol implements stack.TransportProtocol. type protocol struct { + stack *stack.Stack + number tcpip.TransportProtocolNumber } @@ -62,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue) + return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) + return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. @@ -104,17 +101,17 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -135,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4() stack.TransportProtocol { - return &protocol{ProtocolNumber4} +func NewProtocol4(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6() stack.TransportProtocol { - return &protocol{ProtocolNumber6} +func NewProtocol6(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber6} } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 81093e9ca..072601d2d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -83,6 +83,8 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -192,13 +194,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes return ep.ReadPacket(addr, nil) } -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -210,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes tcpip.ErrNotSupported. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrNotSupported } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return tcpip.ErrNotSupported } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with // Accept, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -267,12 +269,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -298,10 +300,16 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + ep.mu.Lock() + ep.linger = *v + ep.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -366,12 +374,21 @@ func (ep *endpoint) LastError() *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + ep.mu.Lock() + *o = ep.linger + ep.mu.Unlock() + return nil + + default: + return tcpip.ErrNotSupported + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { +func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrNotSupported } @@ -508,7 +525,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (*endpoint) State() uint32 { return 0 } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 71feeb748..e37c00523 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -84,6 +84,8 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -446,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -482,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -511,10 +513,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch opt.(type) { + switch v := opt.(type) { case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -577,8 +585,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 234fb95ce..518449602 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -69,6 +69,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", @@ -93,6 +94,7 @@ go_test( shard_count = 10, deps = [ ":tcp", + "//pkg/rand", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 72df5c2a1..6891fd245 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -522,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled SACKEnabled + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. @@ -747,6 +747,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) + pkt.TransportProtocolNumber = header.TCPProtocolNumber tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -897,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -1002,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. @@ -1135,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { } cont, err := e.handleSegment(s) + s.decRef() if err != nil { - s.decRef() return err } if !cont { - s.decRef() return nil } } @@ -1220,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } + // Increase counter if after processing the segment we would potentially + // advertise a zero window. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. @@ -1232,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // or a notification from the protocolMainLoop (caller goroutine). // This means that with this return, the segment dequeue below can // never occur on a closed endpoint. - s.decRef() return false, nil } @@ -1424,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.rcv.nonZeroWindow() } - if n¬ifyReceiveWindowChanged != 0 { - e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) - } - if n¬ifyMTUChanged != 0 { e.sndBufMu.Lock() count := e.packetTooBigCount diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 6074cc24e..560b4904c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -78,8 +78,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv4(t, c.GetPacket(), ackCheckers...) @@ -185,8 +185,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv6(t, c.GetV6Packet(), ackCheckers...) @@ -283,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -319,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -352,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -371,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -492,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -510,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() + var addr tcpip.FullAddress + nep, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -526,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) { } } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) } } @@ -566,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } const n = uint16(32) @@ -610,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c5d9eba5d..ae817091a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,6 +63,17 @@ const ( StateClosing ) +const ( + // rcvAdvWndScale is used to split the available socket buffer into + // application buffer and the window to be advertised to the peer. This is + // currently hard coded to split the available space equally. + rcvAdvWndScale = 1 + + // SegOverheadFactor is used to multiply the value provided by the + // user on a SetSockOpt for setting the socket send/receive buffer sizes. + SegOverheadFactor = 2 +) + // connected returns true when s is one of the states representing an // endpoint connected to a peer. func (s EndpointState) connected() bool { @@ -149,7 +160,6 @@ func (s EndpointState) String() string { // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota - notifyReceiveWindowChanged notifyClose notifyMTUChanged notifyDrain @@ -384,13 +394,26 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + // rcvBufSize is the total size of the receive buffer. + rcvBufSize int + // rcvBufUsed is the actual number of payload bytes held in the receive buffer + // not counting any overheads of the segments itself. NOTE: This will always + // be strictly <= rcvMemUsed below. rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams + // rcvMemUsed tracks the total amount of memory in use by received segments + // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // compute the window and the actual available buffer space. This is distinct + // from rcvBufUsed above which is the actual number of payload bytes held in + // the buffer not including any segment overheads. + // + // rcvMemUsed must be accessed atomically. + rcvMemUsed int32 + // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. mu sync.Mutex `state:"nosave"` @@ -852,12 +875,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -867,12 +890,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.cc = cs } - var mrb tcpip.ModerateReceiveBufferOption + var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - var de DelayEnabled + var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.probe = p } - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) @@ -904,7 +927,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result := waiter.EventMask(0) switch e.EndpointState() { - case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: + case StateInitial, StateBound: + // This prevents blocking of new sockets which are not + // connected when SO_LINGER is set. + result |= waiter.EventHUp + + case StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. case StateClose, StateError, StateTimeWait: @@ -1075,6 +1103,8 @@ func (e *endpoint) closeNoShutdownLocked() { e.notifyProtocolGoroutine(notifyClose) } else { e.transitionToStateCloseLocked() + // Notify that the endpoint is closed. + e.waiterQueue.Notify(waiter.EventHUp) } } @@ -1129,10 +1159,16 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// wndFromSpace returns the window that we can advertise based on the available +// receive buffer space. +func wndFromSpace(space int) int { + return space / (1 << rcvAdvWndScale) +} + // initialReceiveWindow returns the initial receive window to advertise in the // SYN/SYN-ACK. func (e *endpoint) initialReceiveWindow() int { - rcvWnd := e.receiveBufferAvailable() + rcvWnd := wndFromSpace(e.receiveBufferAvailable()) if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } @@ -1209,14 +1245,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = rcvWnd - availAfter := e.receiveBufferAvailableLocked() - mask := uint32(notifyReceiveWindowChanged) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -1293,18 +1327,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { v := views[s.viewToDeliver] s.viewToDeliver++ + var delta int if s.viewToDeliver >= len(views) { e.rcvList.Remove(s) + // We only free up receive buffer space when the segment is released as the + // segment is still holding on to the views even though some views have been + // read out to the user. + delta = s.segMemSize() s.decRef() } e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1317,14 +1355,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { - // The endpoint cannot be written to if it's not connected. - if !e.EndpointState().connected() { - switch e.EndpointState() { - case StateError: - return 0, e.HardError - default: - return 0, tcpip.ErrClosedForSend - } + switch s := e.EndpointState(); { + case s == StateError: + return 0, e.HardError + case !s.connecting() && !s.connected(): + return 0, tcpip.ErrClosedForSend + case s.connecting(): + // As per RFC793, page 56, a send request arriving when in connecting + // state, can be queued to be completed after the state becomes + // connected. Return an error code for the caller of endpoint Write to + // try again, until the connection handshake is complete. + return 0, tcpip.ErrWouldBlock } // Check if the connection has already been closed for sends. @@ -1478,11 +1519,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } // windowCrossedACKThresholdLocked checks if the receive window to be announced -// now would be under aMSS or under half receive buffer, whichever smaller. This -// is useful as a receive side silly window syndrome prevention mechanism. If -// window grows to reasonable value, we should send ACK to the sender to inform -// the rx space is now large. We also want ensure a series of small read()'s -// won't trigger a flood of spurious tiny ACK's. +// would be under aMSS or under the window derived from half receive buffer, +// whichever smaller. This is useful as a receive side silly window syndrome +// prevention mechanism. If window grows to reasonable value, we should send ACK +// to the sender to inform the rx space is now large. We also want ensure a +// series of small read()'s won't trigger a flood of spurious tiny ACK's. // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of @@ -1493,17 +1534,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := e.receiveBufferAvailableLocked() + newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 } - threshold := int(e.amss) - if threshold > e.rcvBufSize/2 { - threshold = e.rcvBufSize / 2 + // rcvBufFraction is the inverse of the fraction of receive buffer size that + // is used to decide if the available buffer space is now above it. + const rcvBufFraction = 2 + if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + threshold = wndThreshold } - switch { case oldAvail < threshold && newAvail >= threshold: return true, true @@ -1632,18 +1674,24 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs ReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs tcpip.TCPReceiveBufferSizeRangeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) + } + + if v > rs.Max { + v = rs.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < rs.Min { v = rs.Min } - if v > rs.Max { - v = rs.Max - } + } else { + v = math.MaxInt32 } - mask := uint32(notifyReceiveWindowChanged) - e.LockUser() e.rcvListMu.Lock() @@ -1657,14 +1705,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { v = 1 << scale } - // Make sure 2*size doesn't overflow. - if v > math.MaxInt32/2 { - v = math.MaxInt32 / 2 - } - - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = v - availAfter := e.receiveBufferAvailableLocked() + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvAutoParams.disabled = true @@ -1672,24 +1715,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } e.rcvListMu.Unlock() e.UnlockUser() - e.notifyProtocolGoroutine(mask) case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss tcpip.TCPSendBufferSizeRangeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) + } + + if v > ss.Max { + v = ss.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < ss.Min { v = ss.Min } - if v > ss.Max { - v = ss.Max - } + } else { + v = math.MaxInt32 } e.sndBufMu.Lock() @@ -1722,7 +1772,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1771,7 +1821,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Query the available cc algorithms in the stack and // validate that the specified algorithm is actually // supported in the stack. - var avail tcpip.AvailableCongestionControlOption + var avail tcpip.TCPAvailableCongestionControlOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { return err } @@ -2047,8 +2097,10 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { e.UnlockUser() case *tcpip.OriginalDestinationOption: + e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID) + addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + e.UnlockUser() if err != nil { return err } @@ -2486,7 +2538,9 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +// +// addr if not-nil will contain the peer address of the returned endpoint. +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2508,6 +2562,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } + if peerAddr != nil { + *peerAddr = n.getRemoteAddress() + } return n, n.waiterQueue, nil } @@ -2610,11 +2667,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotConnected } + return e.getRemoteAddress(), nil +} + +func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort, NIC: e.boundNICID, - }, nil + } } func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -2685,13 +2746,8 @@ func (e *endpoint) updateSndBufferUsage(v int) { func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { + e.rcvBufUsed += s.payloadSize() s.incRef() - e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down below MSS - // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -2706,15 +2762,17 @@ func (e *endpoint) readyToRead(s *segment) { func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. - if e.rcvBufUsed >= e.rcvBufSize { + memUsed := e.receiveMemUsed() + if memUsed >= e.rcvBufSize { return 0 } - return e.rcvBufSize - e.rcvBufUsed + return e.rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. +// receive buffer based on the actual memory used by all segments held in +// receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvListMu.Lock() available := e.receiveBufferAvailableLocked() @@ -2722,16 +2780,37 @@ func (e *endpoint) receiveBufferAvailable() int { return available } +// receiveBufferUsed returns the amount of in-use receive buffer. +func (e *endpoint) receiveBufferUsed() int { + e.rcvListMu.Lock() + used := e.rcvBufUsed + e.rcvListMu.Unlock() + return used +} + +// receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { e.rcvListMu.Lock() size := e.rcvBufSize e.rcvListMu.Unlock() - return size } +// receiveMemUsed returns the total memory in use by segments held by this +// endpoint. +func (e *endpoint) receiveMemUsed() int { + return int(atomic.LoadInt32(&e.rcvMemUsed)) +} + +// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed. +func (e *endpoint) updateReceiveMemUsed(delta int) { + atomic.AddInt32(&e.rcvMemUsed, int32(delta)) +} + +// maxReceiveBufferSize returns the stack wide maximum receive buffer size for +// an endpoint. func (e *endpoint) maxReceiveBufferSize() int { - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2811,7 +2890,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled + var v tcpip.TCPSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return @@ -2880,7 +2959,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RcvAcc: e.rcv.rcvAcc, RcvWndScale: e.rcv.rcvWndScale, PendingBufUsed: e.rcv.pendingBufUsed, - PendingBufSize: e.rcv.pendingBufSize, } // Copy sender state. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 723e47ddc..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() { // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. - e.segmentQueue.setLimit(0) + e.segmentQueue.freeze() e.mu.Lock() defer e.mu.Unlock() @@ -178,18 +178,18 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c5afa2680..5bce73605 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package tcp contains the implementation of the TCP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing tcp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package tcp contains the implementation of the TCP transport protocol. package tcp import ( @@ -29,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -79,50 +75,6 @@ const ( ccCubic = "cubic" ) -// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// Recovery is used by stack.(*Stack).TransportProtocolOption to -// set loss detection algorithm in TCP. -type Recovery int32 - -const ( - // RACKLossDetection indicates RACK is used for loss detection and - // recovery. - RACKLossDetection Recovery = 1 << iota - - // RACKStaticReoWnd indicates the reordering window should not be - // adjusted when DSACK is received. - RACKStaticReoWnd - - // RACKNoDupTh indicates RACK should not consider the classic three - // duplicate acknowledgements rule to mark the segments as lost. This - // is used when reordering is not detected. - RACKNoDupTh -) - -// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type DelayEnabled bool - -// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max TCP send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// TCP receive buffer sizes. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -181,12 +133,14 @@ func (s *synRcvdCounter) Threshold() uint64 { } type protocol struct { + stack *stack.Stack + mu sync.RWMutex sackEnabled bool - recovery Recovery + recovery tcpip.TCPRecovery delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.TCPSendBufferSizeRangeOption + recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -207,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid tcp packet size. @@ -244,21 +198,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { - return false + return stack.UnknownDestinationPacketMalformed } - // There's nothing to do if this is already a reset packet. - if s.flagIsSet(header.TCPFlagRst) { - return true + if !s.flagIsSet(header.TCPFlagRst) { + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) } - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) - return true + return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. @@ -296,49 +249,49 @@ func replyWithReset(s *segment, tos, ttl uint8) { } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.Lock() - p.sackEnabled = bool(v) + p.sackEnabled = bool(*v) p.mu.Unlock() return nil - case Recovery: + case *tcpip.TCPRecovery: p.mu.Lock() - p.recovery = Recovery(v) + p.recovery = *v p.mu.Unlock() return nil - case DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.Lock() - p.delayEnabled = bool(v) + p.delayEnabled = bool(*v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.sendBufferSize = v + p.sendBufferSize = *v p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.recvBufferSize = v + p.recvBufferSize = *v p.mu.Unlock() return nil - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { - if string(v) == c { + if string(*v) == c { p.mu.Lock() - p.congestionControl = string(v) + p.congestionControl = string(*v) p.mu.Unlock() return nil } @@ -347,75 +300,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // is specified. return tcpip.ErrNoSuchFile - case tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() - p.moderateReceiveBuffer = bool(v) + p.moderateReceiveBuffer = bool(*v) p.mu.Unlock() return nil - case tcpip.TCPLingerTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPLingerTimeoutOption: p.mu.Lock() - p.lingerTimeout = time.Duration(v) + if *v < 0 { + p.lingerTimeout = 0 + } else { + p.lingerTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPTimeWaitTimeoutOption: p.mu.Lock() - p.timeWaitTimeout = time.Duration(v) + if *v < 0 { + p.timeWaitTimeout = 0 + } else { + p.timeWaitTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitReuseOption: - if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + case *tcpip.TCPTimeWaitReuseOption: + if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.timeWaitReuse = v + p.timeWaitReuse = *v p.mu.Unlock() return nil - case tcpip.TCPMinRTOOption: - if v < 0 { - v = tcpip.TCPMinRTOOption(MinRTO) - } + case *tcpip.TCPMinRTOOption: p.mu.Lock() - p.minRTO = time.Duration(v) + if *v < 0 { + p.minRTO = MinRTO + } else { + p.minRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRTOOption: - if v < 0 { - v = tcpip.TCPMaxRTOOption(MaxRTO) - } + case *tcpip.TCPMaxRTOOption: p.mu.Lock() - p.maxRTO = time.Duration(v) + if *v < 0 { + p.maxRTO = MaxRTO + } else { + p.maxRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRetriesOption: + case *tcpip.TCPMaxRetriesOption: p.mu.Lock() - p.maxRetries = uint32(v) + p.maxRetries = uint32(*v) p.mu.Unlock() return nil - case tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(v)) + p.synRcvdCount.SetThreshold(uint64(*v)) p.mu.Unlock() return nil - case tcpip.TCPSynRetriesOption: - if v < 1 || v > 255 { + case *tcpip.TCPSynRetriesOption: + if *v < 1 || *v > 255 { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.synRetries = uint8(v) + p.synRetries = uint8(*v) p.mu.Unlock() return nil @@ -425,33 +382,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.TCPSACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *Recovery: + case *tcpip.TCPRecovery: p.mu.RLock() - *v = Recovery(p.recovery) + *v = tcpip.TCPRecovery(p.recovery) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.TCPDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -463,15 +420,15 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil - case *tcpip.AvailableCongestionControlOption: + case *tcpip.TCPAvailableCongestionControlOption: p.mu.RLock() - *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.RUnlock() return nil - case *tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.RLock() - *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer) p.mu.RUnlock() return nil @@ -546,33 +503,19 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TCP header is variable length, peek at it first. - hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) - if !ok { - return false - } - - // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of - // packets. - hdrLen = offset - } - - _, ok = pkt.TransportHeader().Consume(hdrLen) - return ok + return parse.TCP(pkt) } // NewProtocol returns a TCP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(s *stack.Stack) stack.TransportProtocol { p := protocol{ - sendBufferSize: SendBufferSizeOption{ + stack: s, + sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: ReceiveBufferSizeOption{ + recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -587,7 +530,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: RACKLossDetection, + recovery: tcpip.TCPRACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 5e0bfe585..48bf196d8 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -47,22 +47,24 @@ type receiver struct { closed bool + // pendingRcvdSegments is bounded by the receive buffer size of the + // endpoint. pendingRcvdSegments segmentHeap - pendingBufUsed seqnum.Size - pendingBufSize seqnum.Size + // pendingBufUsed tracks the total number of bytes (including segment + // overhead) currently queued in pendingRcvdSegments. + pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, lastRcvdAckTime: time.Now(), } } @@ -85,15 +87,30 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the available buffer space. - receiveBufferAvailable := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) - if r.rcvAcc.LessThan(acc) { - r.rcvAcc = acc + avail := wndFromSpace(r.ep.receiveBufferAvailable()) + if avail == 0 { + // We have no space available to accept any data, move to zero window + // state. + r.rcvWnd = 0 + return r.rcvNxt, 0 + } + + acc := r.rcvNxt.Add(seqnum.Size(avail)) + newWnd := r.rcvNxt.Size(acc) + curWnd := r.rcvNxt.Size(r.rcvAcc) + + // Update rcvAcc only if new window is > previously advertised window. We + // should never shrink the acceptable sequence space once it has been + // advertised the peer. If we shrink the acceptable sequence space then we + // would end up dropping bytes that might already be in flight. + if newWnd > curWnd { + r.rcvAcc = r.rcvNxt.Add(newWnd) + } else { + newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. - r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + r.rcvWnd = newWnd return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } @@ -195,7 +212,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of // truncated items, thus truncated items must be set to nil to avoid // memory leaks. @@ -268,14 +287,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { - case StateCloseWait: - // If the ACK acks something not yet sent then we send an ACK. - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { - r.ep.snd.sendAck() - return true, nil - } - fallthrough - case StateClosing, StateLastAck: + case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { // Just drop the segment as we have // already received a FIN and this @@ -284,9 +296,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo return true, nil } fallthrough - case StateFinWait1: - fallthrough - case StateFinWait2: + case StateFinWait1, StateFinWait2: + // If the ACK acks something not yet sent then we send an ACK. + // + // RFC793, page 37: If the connection is in a synchronized state, + // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, + // TIME-WAIT), any unacceptable segment (out of window sequence number + // or unacceptable acknowledgment number) must elicit only an empty + // acknowledgment segment containing the current send-sequence number + // and an acknowledgment indicating the next sequence number expected + // to be received, and the connection remains in the same state. + // + // Just as on Linux, we do not apply this behavior when state is + // ESTABLISHED. + // Linux receive processing for all states except ESTABLISHED and + // TIME_WAIT is here where if the ACK check fails, we attempt to + // reply back with an ACK with correct seq/ack numbers. + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186 + // The ESTABLISHED state processing is here where if the ACK check + // fails, we ignore the packet: + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., // SHUT_RD) then any data past the rcvNxt should @@ -369,10 +403,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { - // We only store the segment if it's within our buffer - // size limit. - if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += seqnum.Size(s.segMemSize()) + // We only store the segment if it's within our buffer size limit. + // + // Only use 75% of the receive buffer queue for out-of-order + // segments. This ensures that we always leave some space for the inorder + // segments to arrive allowing pending segments to be processed and + // delivered to the user. + if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvListMu.Lock() + r.pendingBufUsed += s.segMemSize() + r.ep.rcvListMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -406,7 +446,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= seqnum.Size(s.segMemSize()) + r.ep.rcvListMu.Lock() + r.pendingBufUsed -= s.segMemSize() + r.ep.rcvListMu.Unlock() s.decRef() } return false, nil @@ -421,6 +463,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // Just silently drop any RST packets in TIME_WAIT. We do not support // TIME_WAIT assasination as a result we confirm w/ fix 1 as described // in https://tools.ietf.org/html/rfc1337#section-3. + // + // This behavior overrides RFC793 page 70 where we transition to CLOSED + // on receiving RST, which is also default Linux behavior. + // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337. + // + // As we do not yet support PAWS, we are being conservative in ignoring + // RSTs by default. if s.flagIsSet(header.TCPFlagRst) { return false, false } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 94307d31a..13acaf753 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "sync/atomic" "time" @@ -24,6 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// queueFlags are used to indicate which queue of an endpoint a particular segment +// belongs to. This is used to track memory accounting correctly. +type queueFlags uint8 + +const ( + recvQ queueFlags = 1 << iota + sendQ +) + // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. // segment is mostly immutable, the only field allowed to change is viewToDeliver. @@ -32,6 +42,8 @@ import ( type segment struct { segmentEntry refCnt int32 + ep *endpoint + qFlags queueFlags id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -100,6 +112,8 @@ func (s *segment) clone() *segment { rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, + ep: s.ep, + qFlags: s.qFlags, } t.data = s.data.Clone(t.views[:]) return t @@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool { return s.flags&flags == flags } +// setOwner sets the owning endpoint for this segment. Its required +// to be called to ensure memory accounting for receive/send buffer +// queues is done properly. +func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) { + switch qFlags { + case recvQ: + ep.updateReceiveMemUsed(s.segMemSize()) + case sendQ: + // no memory account for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b", qFlags)) + } + s.ep = ep + s.qFlags = qFlags +} + func (s *segment) decRef() { if atomic.AddInt32(&s.refCnt, -1) == 0 { + if s.ep != nil { + switch s.qFlags { + case recvQ: + s.ep.updateReceiveMemUsed(-s.segMemSize()) + case sendQ: + // no memory accounting for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) + } + } s.route.Release() } } @@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// payloadSize is the size of s.data. +func (s *segment) payloadSize() int { + return s.data.Size() +} + // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 48a257137..54545a1b1 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -22,16 +22,16 @@ import ( // // +stateify savable type segmentQueue struct { - mu sync.Mutex `state:"nosave"` - list segmentList `state:"wait"` - limit int - used int + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + ep *endpoint + frozen bool } // emptyLocked determines if the queue is empty. // Preconditions: q.mu must be held. func (q *segmentQueue) emptyLocked() bool { - return q.used == 0 + return q.list.Empty() } // empty determines if the queue is empty. @@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool { return r } -// setLimit updates the limit. No segments are immediately dropped in case the -// queue becomes full due to the new limit. -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} - // enqueue adds the given segment to the queue. // // Returns true when the segment is successfully added to the queue, in which @@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) { // false if the queue is full, in which case ownership is retained by the // caller. func (q *segmentQueue) enqueue(s *segment) bool { + // q.ep.receiveBufferParams() must be called without holding q.mu to + // avoid lock order inversion. + bufSz := q.ep.receiveBufferSize() + used := q.ep.receiveMemUsed() q.mu.Lock() - r := q.used < q.limit - if r { + // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue + // is currently full). + allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + + if allow { q.list.PushBack(s) - q.used++ + // Set the owner now that the endpoint owns the segment. + s.setOwner(q.ep, recvQ) } q.mu.Unlock() - return r + return allow } // dequeue removes and returns the next segment from queue, if one exists. @@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment { s := q.list.Front() if s != nil { q.list.Remove(s) - q.used-- } q.mu.Unlock() return s } + +// freeze prevents any more segments from being added to the queue. i.e all +// future segmentQueue.enqueue will return false and not add the segment to the +// queue till the queue is unfroze with a corresponding segmentQueue.thaw call. +func (q *segmentQueue) freeze() { + q.mu.Lock() + q.frozen = true + q.mu.Unlock() +} + +// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and +// allows new segments to be queued again. +func (q *segmentQueue) thaw() { + q.mu.Lock() + q.frozen = false + q.mu.Unlock() +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 99521f0c1..ef7f5719f 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) + opt := tcpip.TCPSACKEnabled(enable) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index adb32e428..5b504d0d1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -240,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) { } } +// TestTCPResetsSentNoICMP confirms that we don't get an ICMP +// DstUnreachable packet when we try send a packet which is not part +// of an active session. +func TestTCPResetsSentNoICMP(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + + // Send a SYN request for a closed port. This should elicit an RST + // but NOT an ICMPv4 DstUnreachable packet. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive whatever comes back. + b := c.GetPacket() + ipHdr := header.IPv4(b) + if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { + t.Errorf("unexpected protocol, got = %d, want = %d", got, want) + } + + // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. + sent := stats.ICMP.V4PacketsSent + if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { + t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) + } +} + // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates // a RST if an ACK is received on the listening socket for which there is no // active handshake in progress and we are not using SYN cookies. @@ -291,12 +324,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -309,16 +342,16 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Lower stackwide TIME_WAIT timeout so that the reservations // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) } c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -348,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -432,8 +465,9 @@ func TestConnectResetAfterClose(t *testing.T) { // Set TCPLinger to 3 seconds so that sockets are marked closed // after 3 second in FIN_WAIT2 state. tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) + opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -446,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -488,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence number // of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -506,8 +540,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -527,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -563,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -610,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -631,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -691,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -743,8 +778,8 @@ func TestSimpleReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -933,8 +968,8 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { // Set the SynRcvd threshold to force a syn cookie based accept to happen. opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { @@ -996,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxSynAckV6(t *testing.T) { @@ -1024,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, @@ -1061,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, @@ -1099,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestSendRstOnListenerRxAckV4(t *testing.T) { @@ -1128,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxAckV6(t *testing.T) { @@ -1156,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestListenShutdown tests for the listening endpoint replying with RST @@ -1272,8 +1307,8 @@ func TestTOSV4(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1321,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1512,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1563,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1574,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) + rcvBufSz := math.MaxUint16 + c.CreateConnected(789, 30000, rcvBufSz) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) @@ -1596,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1617,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1637,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(793), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1679,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1694,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1748,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1762,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { @@ -1781,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence // number of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), + checker.TCPSeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1827,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, 10) + const rcvBufSz = 10 + c.CreateConnected(789, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1838,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("Read failed: %s", err) } - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies + // the provided buffer value by tcp.SegOverheadFactor to calculate the actual + // receive buffer size. + data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) + for i := range data { + data[i] = byte(i % 255) + } c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -1860,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) @@ -1886,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), + checker.TCPWindow(10), ), ) } @@ -1898,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + // Start off with a certain receive buffer then cut it in half and verify that + // the right edge of the window does not shrink. + // NOTE: Netstack doubles the value specified here. + rcvBufSize := 65536 + iss := seqnum.Value(789) + // Enable window scaling with a scale of zero from our end. + c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1912,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ + // Send a 1 byte payload so that we can record the current receive window. + // Send a payload of half the size of rcvBufSize. + seqNum := iss.Add(1) + payload := []byte{1} + c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1931,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("Timed out waiting for data to arrive") } - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), + // Read the 1 byte payload we just sent. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if got, want := payload, v; !bytes.Equal(got, want) { + t.Fatalf("got data: %v, want: %v", got, want) + } + + seqNum = seqNum.Add(1) + // Verify that the ACK does not shrink the window. + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), ), ) + // Stash the initial window. + initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + // Now shrink the receive buffer to half its original size. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) + } - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ + data := generateRandomPayload(t, rcvBufSize) + // Send a payload of half the size of rcvBufSize. + c.SendPacket(data[:rcvBufSize/2], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") + // Verify that the ACK does not shrink the window. + pkt = c.GetPacket() + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { + t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } + // Send another payload of half the size of rcvBufSize. This should fill up the + // socket receive buffer and we should see a zero window. + c.SendPacket(data[rcvBufSize/2:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + // Receive data and check it. - read := make([]byte, 0, 10) + read := make([]byte, 0, rcvBufSize) for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { @@ -1984,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data = %v, want = %v", read, data) } - // Check that we get an ACK for the newly non-zero window, which is the - // new size. + // Check that we get an ACK for the newly non-zero window, which is the new + // receive buffer size we set after the connection was established. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), + checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), ) } @@ -2017,8 +2109,8 @@ func TestSimpleSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2059,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2081,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2121,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2160,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2196,19 +2288,20 @@ func TestScaledWindowAccept(t *testing.T) { } // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2226,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2277,12 +2370,12 @@ func TestNonScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2307,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2322,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ + // Set the buffer size such that a window scale of 5 will be used. + const bufSz = 65535 * 10 + const ws = uint32(5) + c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) // Write chunks of 50000 bytes. - remain := wnd + remain := 0 sent := 0 data := make([]byte, 50000) - for remain > len(data) { + // Keep writing till the window drops below len(data). + for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -2343,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Don't reduce window to zero here. + if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { + remain = wnd << ws + break + } } // Make the window non-zero, but the scaled window zero. - if remain >= 16 { + for remain >= 16 { data = data[:remain-15] c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, @@ -2368,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Since the receive buffer is split between window advertisement and + // application data buffer the window does not always reflect the space + // available and actual space available can be a bit more than what is + // advertised in the window. + wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) + if wnd == 0 { + break + } + remain = wnd << ws } - // Read at least 1MSS of data. An ack should be sent in response to that. + // Read at least 2MSS of data. An ack should be sent in response to that. + // Since buffer space is now split in half between window and application + // data we need to read more than 1 MSS(65536) of data for a non-zero window + // update to be sent. For 1MSS worth of window to be available we need to + // read at least 128KB. Since our segments above were 50KB each it means + // we need to read at 3 packets. sz := 0 - for sz < defaultMTU { + for sz < defaultMTU*2 { v, _, err := c.EP.Read(nil) if err != nil { t.Fatalf("Read failed: %s", err) @@ -2395,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(sz>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2464,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize+1), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+uint32(i)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2487,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+11), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+11), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2535,8 +2646,8 @@ func TestDelay(t *testing.T) { checker.PayloadLen(len(want)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2582,8 +2693,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2605,8 +2716,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2667,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2719,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2840,12 +2951,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2867,8 +2978,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(0) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -2895,12 +3007,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2961,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - const wndScale = 2 + const wndScale = 3 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } @@ -2996,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), + checker.TCPSeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -3017,8 +3129,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -3146,8 +3258,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { defer c.Cleanup() const numRetries = 2 - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { - t.Fatalf("could not set protocol option MaxRetries.\n") + opt := tcpip.TCPMaxRetriesOption(numRetries) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3206,8 +3319,9 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + opt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3309,8 +3423,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3330,8 +3444,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3352,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3363,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3384,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3408,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3433,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3455,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3483,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3502,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3522,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3543,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3567,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3592,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3608,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3629,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3654,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3675,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3690,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3706,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3798,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.PayloadLen((1<<scale)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3937,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3964,8 +4078,9 @@ func TestReadAfterClosedState(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -3987,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -4012,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(791+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4185,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func TestDefaultBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4204,11 +4319,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4221,11 +4340,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4240,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) { func TestMinMaxBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4252,22 +4375,28 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - // Set values below the min. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { + // Set values below the min/2. + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } @@ -4278,19 +4407,21 @@ func TestMinMaxBufferSizes(t *testing.T) { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) + // Values above max are capped at max and then doubled. + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) + // Values above max are capped at max and then doubled. + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -4339,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) { func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() @@ -4626,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) { checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), + checker.TCPSeqNum(seqNum), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4718,8 +4849,8 @@ func TestStackSetCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption @@ -4751,12 +4882,12 @@ func TestStackAvailableCongestionControl(t *testing.T) { s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption + var aCC tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) } } @@ -4767,18 +4898,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") + aCC := tcpip.TCPAvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption + var cc tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) } } @@ -4842,8 +4973,8 @@ func TestEndpointSetCongestionControl(t *testing.T) { func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -4878,8 +5009,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4912,8 +5043,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4924,8 +5055,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) @@ -4950,8 +5081,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(next-1)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4977,8 +5108,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -5018,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -5062,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -5135,12 +5266,12 @@ func TestListenBacklogFull(t *testing.T) { defer c.WQ.EventUnregister(&we) for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5152,7 +5283,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5164,12 +5295,12 @@ func TestListenBacklogFull(t *testing.T) { // Now a new handshake must succeed. executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5194,6 +5325,8 @@ func TestListenBacklogFull(t *testing.T) { func TestListenNoAcceptNonUnicastV4(t *testing.T) { multicastAddr := tcpip.Address("\xe0\x00\x01\x02") otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() tests := []struct { name string @@ -5201,53 +5334,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { dstAddr tcpip.Address }{ { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, }, { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, }, { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, }, { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, }, { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, }, { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, }, { - "DestOurMulticast", - context.TestAddr, - multicastAddr, + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, }, { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, }, } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5288,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5296,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") tests := []struct { name string @@ -5347,11 +5486,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5392,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5440,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5476,12 +5611,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5505,8 +5640,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(1) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create TCP endpoint. @@ -5552,12 +5688,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5568,7 +5704,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5617,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5638,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.AckNum(uint32(irs) + 1), - checker.SeqNum(uint32(iss + 1)), + checker.TCPAckNum(uint32(irs) + 1), + checker.TCPSeqNum(uint32(iss + 1)), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5657,7 +5793,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err != nil && err != tcpip.ErrWouldBlock { t.Fatalf("Accept failed: %s", err) @@ -5672,7 +5808,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5730,12 +5866,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5800,12 +5936,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5853,12 +5989,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - aep, _, err = ep.Accept() + aep, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5906,13 +6042,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -5931,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { time.Sleep(latency) rawEP.SendPacketWithTS([]byte{1}, tsVal) - // Verify that the ACK has the expected window. - wantRcvWnd := receiveBufferSize - wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 + payloadSize := receiveBufferSize * 2 + b := make([]byte, int(payloadSize)) + worker := (c.EP).(interface { StopWork() ResumeWork() @@ -5949,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Stop the worker goroutine. worker.StopWork() - start := offset - end := offset + payloadSize + start := 0 + end := payloadSize / 2 packetsSent := 0 for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + packetEnd := start + mss + if start+mss > end { + packetEnd = end + } + rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) packetsSent++ } @@ -5961,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // are waiting to be read. worker.ResumeWork() - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } + // Since we sent almost the full receive buffer worth of data (some may have + // been dropped due to segment overheads), we should get a zero window back. + pkt = c.GetPacket() + tcpHdr := header.TCP(header.IPv4(pkt).Payload()) + gotRcvWnd := tcpHdr.WindowSize() + wantAckNum := tcpHdr.AckNumber() + if got, want := int(gotRcvWnd), 0; got != want { + t.Fatalf("got rcvWnd: %d, want: %d", got, want) } - rawEP.VerifyACKRcvWnd(0) time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. + // Verify that sending more data when receiveBuffer is exhausted. rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { @@ -5996,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Verify that we receive a non-zero window update ACK. When running // under thread santizer this test can end up sending more than 1 // ack, 1 for the non-zero window - p = c.GetPacket() + p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), + checker.TCPAckNum(uint32(wantAckNum)), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { return } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) + // We use 10% here as the error margin upwards as the initial window we + // got was afer 1 segment was already in the receive buffer queue. + tolerance := 1.1 + if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { + t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) } }, )) } -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. +// This test verifies that the advertised window is auto-tuned up as the +// application is reading the data that is being received. func TestReceiveBufferAutoTuning(t *testing.T) { const mtu = 1500 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize @@ -6022,26 +6160,33 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize + tsVal := uint32(rawEP.TSVal) + rawEP.NextSeqNum-- + rawEP.SendPacketWithTS(nil, tsVal) + rawEP.NextSeqNum++ + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { return uint16(rcvWnd >> uint16(c.WindowScale)) } @@ -6058,14 +6203,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) { StopWork() ResumeWork() }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false latency := 1 * time.Millisecond - for i := 0; !done; i++ { + for i := 0; i < 5; i++ { tsVal++ // Stop the worker goroutine. @@ -6087,15 +6226,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Give 1ms for the worker to process the packets. time.Sleep(1 * time.Millisecond) - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break } + lastACK = pkt + } + if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { + t.Fatalf("advertised window got: %d, want <= %d", got, want) } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) // Now read all the data from the endpoint and invoke the // moderation API to allow for receive buffer auto-tuning @@ -6120,35 +6264,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ - if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied + rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true + pkt := c.GetPacket() + curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // If thew new current window is close maxReceiveBufferSize then terminate + // the loop. This can happen before all iterations are done due to timing + // differences when running the test. + if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { + break } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied // Increase the latency after first two iterations to // establish a low RTT value in the receiver since it // only tracks the lowest value. This ensures that when @@ -6161,6 +6290,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) { offset += payloadSize payloadSize *= 2 } + // Check that at the end of our iterations the receive window grew close to the maximum + // permissible size of maxReceiveBufferSize/2 + if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { + t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) + } + } func TestDelayEnabled(t *testing.T) { @@ -6169,7 +6304,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcp.DelayEnabled + delayEnabled tcpip.TCPDelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -6177,17 +6312,17 @@ func TestDelayEnabled(t *testing.T) { } { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) } checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcp.DelayEnabled + var gotDelayEnabled tcpip.TCPDelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } @@ -6293,12 +6428,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6312,8 +6447,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6330,8 +6465,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now send a RST and this should be ignored and not @@ -6359,8 +6494,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6412,12 +6547,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6431,8 +6566,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6449,8 +6584,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Out of order ACK should generate an immediate ACK in @@ -6466,8 +6601,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6519,12 +6654,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6538,8 +6673,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6556,8 +6691,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Send a SYN request w/ sequence number lower than @@ -6602,12 +6737,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.SendPacket(nil, ackHeaders) // Try to accept the connection. - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6625,8 +6760,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 @@ -6675,12 +6811,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6694,8 +6830,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6712,8 +6848,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) time.Sleep(2 * time.Second) @@ -6727,8 +6863,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Sleep for 4 seconds so at this point we are 1 second past the @@ -6756,8 +6892,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { @@ -6775,8 +6911,9 @@ func TestTCPCloseWithData(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } wq := &waiter.Queue{} @@ -6824,12 +6961,12 @@ func TestTCPCloseWithData(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6855,8 +6992,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now write a few bytes and then close the endpoint. @@ -6874,8 +7011,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -6889,8 +7026,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), + checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.TCPAckNum(uint32(iss+2)), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) // First send a partial ACK. @@ -6935,8 +7072,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -6972,8 +7109,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -7007,8 +7144,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7069,8 +7206,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7095,8 +7232,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7112,9 +7249,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } } -func TestIncreaseWindowOnReceive(t *testing.T) { +func TestIncreaseWindowOnRead(t *testing.T) { // This test ensures that the endpoint sends an ack, - // after recv() when the window grows to more than 1 MSS. + // after read() when the window grows by more than 1 MSS. c := context.New(t, defaultMTU) defer c.Cleanup() @@ -7123,10 +7260,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7139,46 +7275,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) { }) sent += len(data) remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Break once the window drops below defaultMTU/2 + if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { + break + } } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // We now have < 1 MSS in the buffer space. Read the data! An - // ack should be sent in response to that. The window was not - // zero, but it grew to larger than MSS. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) - } - - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) + // We now have < 1 MSS in the buffer space. Read at least > 2 MSS + // worth of data as receive buffer space + read := 0 + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + for read < defaultMTU*2 { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + read += len(v) } - // After reading two packets, we surely crossed MSS. See the ack: + // After reading > MSS worth of data, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7195,10 +7328,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7212,38 +7344,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { sent += len(data) remain -= len(data) - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), ) } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - // After reading two packets, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7271,8 +7394,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7288,14 +7411,14 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7303,8 +7426,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestTCPDeferAcceptTimeout(t *testing.T) { @@ -7329,8 +7452,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7341,7 +7464,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) // Send data. This should result in an acceptable endpoint. c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ @@ -7357,14 +7480,14 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7373,8 +7496,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestResetDuringClose(t *testing.T) { @@ -7399,8 +7522,8 @@ func TestResetDuringClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(irs.Add(1))), - checker.AckNum(uint32(iss.Add(5))))) + checker.TCPSeqNum(uint32(irs.Add(1))), + checker.TCPAckNum(uint32(iss.Add(5))))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7462,9 +7585,10 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } for _, tc := range testCases { - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v)) + opt := tcpip.TCPTimeWaitReuseOption(tc.v) + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) if got, want := err, tc.err; got != want { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) } if tc.err != nil { continue @@ -7480,3 +7604,14 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } } } + +// generateRandomPayload generates a random byte slice of the specified length +// causing a fatal test failure if it is unable to do so. +func generateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 8edbff964..0f9ed06cd 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -158,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.PayloadLen(len(data)+header.TCPMinimumSize+12), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(true, 0, tsVal+1), ), @@ -180,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) @@ -192,8 +194,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -217,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(false, 0, 0), ), @@ -235,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of + // that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1f5340cd0..faf51ef95 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -53,11 +53,11 @@ const ( TestPort = 4096 // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" // TestV6Addr is the source address for packets sent to the stack via // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // StackV4MappedAddr is StackAddr as a mapped v6 address. StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr @@ -73,6 +73,18 @@ const ( testInitialSequenceNumber = 789 ) +// StackAddrWithPrefix is StackAddr with its associated prefix length. +var StackAddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackAddr, + PrefixLen: 24, +} + +// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. +var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackV6Addr, + PrefixLen: header.IIDOffsetInIPv6Address * 8, +} + // Headers is used to represent the TCP header fields when building a // new packet. type Headers struct { @@ -133,32 +145,39 @@ type Context struct { // WindowScale is the expected window scale in SYN packets sent by // the stack. WindowScale uint8 + + // RcvdWindowScale is the actual window scale sent by the stack in + // SYN/SYN-ACK. + RcvdWindowScale uint8 } // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) } // Increase minimum RTO in tests to avoid test flakes due to early // retransmit in case the test executors are overloaded and cause timers // to fire earlier than expected. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { - t.Fatalf("failed to set stack-wide minRTO: %s", err) + minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) } // Some of the congestion control tests send up to 640 packets, we so @@ -181,12 +200,20 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -238,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) { c.CheckNoPacketTimeout(errMsg, 1*time.Second) } -// GetPacket reads a packet from the link layer endpoint and verifies +// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies // that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { +// addresses. If no packet is received in the specified timeout it will return +// nil. +func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { - c.t.Fatalf("Packet wasn't written out") return nil } @@ -257,6 +283,14 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -268,6 +302,21 @@ func (c *Context) GetPacket() []byte { return b } +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. +func (c *Context) GetPacket() []byte { + c.t.Helper() + + p := c.GetPacketWithTimeout(5 * time.Second) + if p == nil { + c.t.Fatalf("Packet wasn't written out") + return nil + } + + return p +} + // GetPacketNonBlocking reads a packet from the link layer endpoint // and verifies that it is an IPv4 packet with the expected source // and destination address. If no packet is available it will return @@ -284,6 +333,14 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -447,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.PayloadLen(size+header.TCPMinimumSize+optlen), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -474,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -613,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) } tcpHdr := header.TCP(header.IPv4(b).Payload()) + synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ @@ -630,8 +688,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) checker.TCP( checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -648,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } + c.RcvdWindowScale = uint8(synOpts.WS) c.Port = tcpHdr.SourcePort() } @@ -719,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) } -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { +// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches +// the provided tsVal as well as returns the original packet. +func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { + r.C.t.Helper() // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPTimestampChecker(true, 0, tsVal), ), ) @@ -737,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) opts := tcpSeg.ParsedOptions() r.RecentTS = opts.TSVal + return ackPacket +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + r.C.t.Helper() + _ = r.VerifyAndReturnACKWithTS(tsVal) } // VerifyACKRcvWnd verifies that the window advertised by the incoming ACK // matches the provided rcvWnd. func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { + r.C.t.Helper() ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), + checker.TCPWindow(rcvWnd), ), ) } @@ -768,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPSACKBlockChecker(sackBlocks), ), ) @@ -861,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * tcpCheckers := []checker.TransportChecker{ checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), + checker.TCPSeqNum(uint32(c.IRS) + 1), + checker.TCPAckNum(uint32(iss) + 1), } // Verify that tsEcr of ACK packet is wantOptions.TSVal if the @@ -897,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Mark in context that timestamp option is enabled for this endpoint. c.TimeStampEnabled = true - + c.RcvdWindowScale = uint8(synOptions.WS) return &RawEndpoint{ C: c, SrcPort: tcpSeg.DestinationPort(), @@ -948,12 +1017,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { c.t.Fatalf("Accept failed: %v", err) } @@ -990,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 offset += header.EncodeMSSOption(uint32(maxPayload), opts) @@ -1028,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // are present. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) + rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) c.IRS = seqnum.Value(tcp.SequenceNumber()) tcpCheckers := []checker.TransportChecker{ checker.SrcPort(StackPort), checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), + checker.TCPAckNum(uint32(iss) + 1), checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), } @@ -1077,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // Send ACK. c.SendPacket(nil, ackHeaders) + c.RcvdWindowScale = uint8(rcvdSynOptions.WS) c.Port = StackPort return &RawEndpoint{ @@ -1096,7 +1168,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled + var v tcpip.TCPSACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index b5d2d0ba6..c78549424 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c74bc4d94..d57ed5d79 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -139,7 +139,7 @@ type endpoint struct { // multicastMemberships that need to be remvoed when the endpoint is // closed. Protected by the mu mutex. - multicastMemberships []multicastMembership + multicastMemberships map[multicastMembership]struct{} // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -154,6 +154,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // +stateify savable @@ -182,12 +185,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // TTL=1. // // Linux defaults to TTL=1. - multicastTTL: 1, - multicastLoop: true, - rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, - state: StateInitial, - uniqueID: s.UniqueID(), + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + multicastMemberships: make(map[multicastMembership]struct{}), + state: StateInitial, + uniqueID: s.UniqueID(), } // Override with stack defaults. @@ -237,10 +241,10 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } - for _, mem := range e.multicastMemberships { + for mem := range e.multicastMemberships { e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } - e.multicastMemberships = nil + e.multicastMemberships = make(map[multicastMembership]struct{}) // Close the receive list and drain it. e.rcvMu.Lock() @@ -752,17 +756,15 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - for _, mem := range e.multicastMemberships { - if mem == memToInsert { - return tcpip.ErrPortInUse - } + if _, ok := e.multicastMemberships[memToInsert]; ok { + return tcpip.ErrPortInUse } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } - e.multicastMemberships = append(e.multicastMemberships, memToInsert) + e.multicastMemberships[memToInsert] = struct{}{} case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { @@ -786,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() - for i, mem := range e.multicastMemberships { - if mem == memToRemove { - memToRemoveIndex = i - break - } - } - if memToRemoveIndex == -1 { + if _, ok := e.multicastMemberships[memToRemove]; !ok { return tcpip.ErrBadLocalAddress } @@ -805,8 +800,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return err } - e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + delete(e.multicastMemberships, memToRemove) case *tcpip.BindToDeviceOption: id := tcpip.NICID(*v) @@ -819,6 +813,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -975,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() + case *tcpip.LingerOption: + e.mu.RLock() + *o = e.linger + e.mu.RUnlock() + default: return tcpip.ErrUnknownProtocolOption } @@ -992,6 +996,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ @@ -1218,7 +1223,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -1392,15 +1397,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - // Never receive from a multicast address. - if header.IsV4MulticastAddress(id.RemoteAddress) || - header.IsV6MulticastAddress(id.RemoteAddress) { - e.stack.Stats().UDP.InvalidSourceAddress.Increment() - e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - return - } - if !verifyChecksum(r, hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 851e6b635..858c99a45 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - for _, m := range e.multicastMemberships { + for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f65751dd4..da5b1deb2 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -12,18 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package udp contains the implementation of the UDP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing udp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package udp contains the implementation of the UDP transport protocol. package udp import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/waiter" @@ -49,6 +45,7 @@ const ( ) type protocol struct { + stack *stack.Stack } // Number returns the udp protocol number. @@ -57,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid udp packet size. @@ -79,135 +76,30 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } -// HandleUnknownDestinationPacket handles packets targeted at this protocol but -// that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +// HandleUnknownDestinationPacket handles packets that are targeted at this +// protocol but don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true + return stack.UnknownDestinationPacketMalformed } if !verifyChecksum(r, hdr, pkt) { - // Checksum Error. r.Stack().Stats().UDP.ChecksumErrors.Increment() - return true + return stack.UnknownDestinationPacketMalformed } - // Only send ICMP error if the address is not a multicast/broadcast - // v4/v6 address or the source is not the unspecified address. - // - // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { - return true - } - - // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination - // Unreachable messages with code: - // - // 2 (Protocol Unreachable), when the designated transport protocol - // is not supported; or - // - // 3 (Port Unreachable), when the designated transport protocol - // (e.g., UDP) is unable to demultiplex the datagram but has no - // protocol mechanism to inform the sender. - switch len(id.LocalAddress) { - case header.IPv4AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() - return true - } - // As per RFC 1812 Section 4.3.2.3 - // - // ICMP datagram SHOULD contain as much of the original - // datagram as possible without the length of the ICMP - // datagram exceeding 576 bytes - // - // NOTE: The above RFC referenced is different from the original - // recommendation in RFC 1122 where it mentioned that at least 8 - // bytes of the payload must be included. Today linux and other - // systems implement the] RFC1812 definition and not the original - // RFC 1122 requirement. - mtu := int(r.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen - payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - - // The buffers used by pkt may be used elsewhere in the system. - // For example, a raw or packet socket may use what UDP - // considers an unreachable destination. Thus we deep copy pkt - // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) - newHeader = append(newHeader, pkt.TransportHeader().View()...) - payload := newHeader.ToVectorisedView() - payload.AppendView(pkt.Data.ToView()) - payload.CapLength(payloadLen) - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, - Data: payload, - }) - icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) - - case header.IPv6AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() - return true - } - - // As per RFC 4443 section 2.4 - // - // (c) Every ICMPv6 error message (type < 128) MUST include - // as much of the IPv6 offending (invoking) packet (the - // packet that caused the error) as possible without making - // the error message packet exceed the minimum IPv6 MTU - // [IPv6]. - mtu := int(r.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize - available := int(mtu) - headerLen - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - payloadLen := len(network) + len(transport) + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) - payload.Append(pkt.Data) - payload.CapLength(payloadLen) - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, - Data: payload, - }) - icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) - } - return true + return stack.UnknownDestinationPacketUnhandled } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -219,11 +111,10 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) - return ok + return parse.UDP(pkt) } // NewProtocol returns a UDP transport protocol. -func NewProtocol() stack.TransportProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0cbc045d8..4e14a5fc5 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -294,8 +294,8 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) } @@ -388,6 +388,10 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } + if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { + c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -528,8 +532,8 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -803,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: tt.handleLocal, }) defer c.cleanup() @@ -924,42 +928,6 @@ func TestReadFromMulticast(t *testing.T) { } } -// TestReadFromMulticaststats checks that a discarded packet -// that that was sent with multicast SOURCE address increments -// the correct counters and that a regular packet does not. -func TestReadFromMulticastStats(t *testing.T) { - t.Helper() - for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - c.injectPacket(flow, payload, false) - - var want uint64 = 0 - if flow.isReverseMulticast() { - want = 1 - } - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { - t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) - } - if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { - t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } - }) - } -} - // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -1462,6 +1430,30 @@ func TestNoChecksum(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1479,16 +1471,19 @@ func TestTTL(t *testing.T) { if flow.isMulticast() { wantTTL = multicastTTL } else { - var p stack.NetworkProtocol + var p stack.NetworkProtocolFactory + var n tcpip.NetworkProtocolNumber if flow.isV4() { - p = ipv4.NewProtocol() + p = ipv4.NewProtocol + n = ipv4.ProtocolNumber } else { - p = ipv6.NewProtocol() + p = ipv6.NewProtocol + n = ipv6.ProtocolNumber } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{p}, + }) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) wantTTL = ep.DefaultTTL() ep.Close() } @@ -1512,18 +1507,6 @@ func TestSetTTL(t *testing.T) { c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - ep.Close() - testWrite(c, flow, checker.TTL(wantTTL)) }) } @@ -2343,17 +2326,18 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - requiresBroadcastOpt: true, + remoteAddr: remNetSubnetBcast, + // TODO(gvisor.dev/issue/3938): Once we support marking a route as + // broadcast, this test should require the broadcast option to be set. + requiresBroadcastOpt: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 052b6b99d..64d17f661 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -22,6 +22,7 @@ import ( "net" "os" "path" + "path/filepath" "regexp" "strconv" "strings" @@ -403,10 +404,13 @@ func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) { return } for _, name := range sources { - src, err := testutil.FindFile(name) - if err != nil { - c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) - return + src := name + if !filepath.IsAbs(src) { + src, err = testutil.FindFile(name) + if err != nil { + c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %w", name, err) + return + } } dst := path.Join(dir, path.Base(name)) if err := testutil.Copy(src, dst); err != nil { diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index b7f873392..06fb823f6 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -332,13 +332,13 @@ func PollContext(ctx context.Context, cb func() error) error { } // WaitForHTTP tries GET requests on a port until the call succeeds or timeout. -func WaitForHTTP(port int, timeout time.Duration) error { +func WaitForHTTP(ip string, port int, timeout time.Duration) error { cb := func() error { c := &http.Client{ // Calculate timeout to be able to do minimum 5 attempts. Timeout: timeout / 5, } - url := fmt.Sprintf("http://localhost:%d/", port) + url := fmt.Sprintf("http://%s:%d/", ip, port) resp, err := c.Get(url) if err != nil { log.Printf("Waiting %s: %v", url, err) diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 27279b409..9b1e7a085 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -21,7 +21,6 @@ import ( "io" "strconv" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safemem" @@ -184,51 +183,6 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) { return n, err } -// CopyObjectOut copies a fixed-size value or slice of fixed-size values from -// src to the memory mapped at addr in uio. It returns the number of bytes -// copied. -// -// CopyObjectOut must use reflection to encode src; performance-sensitive -// clients should do encoding manually and use uio.CopyOut directly. -// -// Preconditions: Same as IO.CopyOut. -func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) { - w := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - // Allocate a byte slice the size of the object being marshaled. This - // adds an extra reflection call, but avoids needing to grow the slice - // during encoding, which can result in many heap-allocated slices. - b := make([]byte, 0, binary.Size(src)) - return w.Write(binary.Marshal(b, ByteOrder, src)) -} - -// CopyObjectIn copies a fixed-size value or slice of fixed-size values from -// the memory mapped at addr in uio to dst. It returns the number of bytes -// copied. -// -// CopyObjectIn must use reflection to decode dst; performance-sensitive -// clients should use uio.CopyIn directly and do decoding manually. -// -// Preconditions: Same as IO.CopyIn. -func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) { - r := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - buf := make([]byte, binary.Size(dst)) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - binary.Unmarshal(buf, ByteOrder, dst) - return int(r.Addr - addr), nil -} - // CopyStringIn tuning parameters, defined outside that function for tests. const ( copyStringIncrement = 64 diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go index bf3c5df2b..da60b0cc7 100644 --- a/pkg/usermem/usermem_test.go +++ b/pkg/usermem/usermem_test.go @@ -16,7 +16,6 @@ package usermem import ( "bytes" - "encoding/binary" "fmt" "reflect" "strings" @@ -174,23 +173,6 @@ type testStruct struct { Uint64 uint64 } -func TestCopyObject(t *testing.T) { - wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} - wantN := binary.Size(wantObj) - b := &BytesIO{make([]byte, wantN)} - ctx := newContext() - if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { - t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - var gotObj testStruct - if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { - t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if gotObj != wantObj { - t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) - } -} - func TestCopyStringInShort(t *testing.T) { // Tests for string length <= copyStringIncrement. want := strings.Repeat("A", copyStringIncrement-2) |