diff options
Diffstat (limited to 'pkg')
219 files changed, 7604 insertions, 3188 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ecaeb11ac..a461bb65e 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -76,7 +76,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/abi", - "//pkg/binary", "//pkg/bits", "//pkg/marshal", "//pkg/marshal/primitive", @@ -86,9 +85,8 @@ go_library( go_test( name = "linux_test", size = "small", - srcs = ["netfilter_test.go"], - library = ":linux", - deps = [ - "//pkg/binary", + srcs = [ + "netfilter_test.go", ], + library = ":linux", ) diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go index 7c9a02f20..c5713541f 100644 --- a/pkg/abi/linux/elf.go +++ b/pkg/abi/linux/elf.go @@ -106,3 +106,53 @@ const ( // NT_ARM_TLS is for ARM TLS register. NT_ARM_TLS = 0x401 ) + +// ElfHeader64 is the ELF64 file header. +// +// +marshal +type ElfHeader64 struct { + Ident [16]byte // File identification. + Type uint16 // File type. + Machine uint16 // Machine architecture. + Version uint32 // ELF format version. + Entry uint64 // Entry point. + Phoff uint64 // Program header file offset. + Shoff uint64 // Section header file offset. + Flags uint32 // Architecture-specific flags. + Ehsize uint16 // Size of ELF header in bytes. + Phentsize uint16 // Size of program header entry. + Phnum uint16 // Number of program header entries. + Shentsize uint16 // Size of section header entry. + Shnum uint16 // Number of section header entries. + Shstrndx uint16 // Section name strings section. +} + +// ElfSection64 is the ELF64 Section header. +// +// +marshal +type ElfSection64 struct { + Name uint32 // Section name (index into the section header string table). + Type uint32 // Section type. + Flags uint64 // Section flags. + Addr uint64 // Address in memory image. + Off uint64 // Offset in file. + Size uint64 // Size in bytes. + Link uint32 // Index of a related section. + Info uint32 // Depends on section type. + Addralign uint64 // Alignment in bytes. + Entsize uint64 // Size of each entry in section. +} + +// ElfProg64 is the ELF64 Program header. +// +// +marshal +type ElfProg64 struct { + Type uint32 // Entry type. + Flags uint32 // Access permission flags. + Off uint64 // File offset of contents. + Vaddr uint64 // Virtual address in memory image. + Paddr uint64 // Physical address (not used). + Filesz uint64 // Size of contents in file. + Memsz uint64 // Size of contents in memory. + Align uint64 // Alignment in memory and file. +} diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go index 1121a1a92..67706f5aa 100644 --- a/pkg/abi/linux/epoll.go +++ b/pkg/abi/linux/epoll.go @@ -14,10 +14,6 @@ package linux -import ( - "gvisor.dev/gvisor/pkg/binary" -) - // Event masks. const ( EPOLLIN = 0x1 @@ -59,4 +55,4 @@ const ( ) // SizeOfEpollEvent is the size of EpollEvent struct. -var SizeOfEpollEvent = int(binary.Size(EpollEvent{})) +var SizeOfEpollEvent = (*EpollEvent)(nil).SizeBytes() diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errors.go index 93f85a864..b08b2687e 100644 --- a/pkg/abi/linux/errors.go +++ b/pkg/abi/linux/errors.go @@ -15,158 +15,149 @@ package linux // Errno represents a Linux errno value. -type Errno struct { - number int - name string -} - -// Number returns the errno number. -func (e *Errno) Number() int { - return e.number -} - -// String implements fmt.Stringer.String. -func (e *Errno) String() string { - return e.name -} +type Errno int // Errno values from include/uapi/asm-generic/errno-base.h. -var ( - EPERM = &Errno{1, "operation not permitted"} - ENOENT = &Errno{2, "no such file or directory"} - ESRCH = &Errno{3, "no such process"} - EINTR = &Errno{4, "interrupted system call"} - EIO = &Errno{5, "I/O error"} - ENXIO = &Errno{6, "no such device or address"} - E2BIG = &Errno{7, "argument list too long"} - ENOEXEC = &Errno{8, "exec format error"} - EBADF = &Errno{9, "bad file number"} - ECHILD = &Errno{10, "no child processes"} - EAGAIN = &Errno{11, "try again"} - ENOMEM = &Errno{12, "out of memory"} - EACCES = &Errno{13, "permission denied"} - EFAULT = &Errno{14, "bad address"} - ENOTBLK = &Errno{15, "block device required"} - EBUSY = &Errno{16, "device or resource busy"} - EEXIST = &Errno{17, "file exists"} - EXDEV = &Errno{18, "cross-device link"} - ENODEV = &Errno{19, "no such device"} - ENOTDIR = &Errno{20, "not a directory"} - EISDIR = &Errno{21, "is a directory"} - EINVAL = &Errno{22, "invalid argument"} - ENFILE = &Errno{23, "file table overflow"} - EMFILE = &Errno{24, "too many open files"} - ENOTTY = &Errno{25, "not a typewriter"} - ETXTBSY = &Errno{26, "text file busy"} - EFBIG = &Errno{27, "file too large"} - ENOSPC = &Errno{28, "no space left on device"} - ESPIPE = &Errno{29, "illegal seek"} - EROFS = &Errno{30, "read-only file system"} - EMLINK = &Errno{31, "too many links"} - EPIPE = &Errno{32, "broken pipe"} - EDOM = &Errno{33, "math argument out of domain of func"} - ERANGE = &Errno{34, "math result not representable"} +const ( + NOERRNO = iota + EPERM + ENOENT + ESRCH + EINTR + EIO + ENXIO + E2BIG + ENOEXEC + EBADF + ECHILD // 10 + EAGAIN + ENOMEM + EACCES + EFAULT + ENOTBLK + EBUSY + EEXIST + EXDEV + ENODEV + ENOTDIR // 20 + EISDIR + EINVAL + ENFILE + EMFILE + ENOTTY + ETXTBSY + EFBIG + ENOSPC + ESPIPE + EROFS // 30 + EMLINK + EPIPE + EDOM + ERANGE + // Errno values from include/uapi/asm-generic/errno.h. + EDEADLK + ENAMETOOLONG + ENOLCK + ENOSYS + ENOTEMPTY + ELOOP //40 + _ // Skip for EWOULDBLOCK = EAGAIN + ENOMSG //42 + EIDRM + ECHRNG + EL2NSYNC + EL3HLT + EL3RST + ELNRNG + EUNATCH + ENOCSI + EL2HLT // 50 + EBADE + EBADR + EXFULL + ENOANO + EBADRQC + EBADSLT + _ // Skip for EDEADLOCK = EDEADLK + EBFONT + ENOSTR // 60 + ENODATA + ETIME + ENOSR + ENONET + ENOPKG + EREMOTE + ENOLINK + EADV + ESRMNT + ECOMM // 70 + EPROTO + EMULTIHOP + EDOTDOT + EBADMSG + EOVERFLOW + ENOTUNIQ + EBADFD + EREMCHG + ELIBACC + ELIBBAD // 80 + ELIBSCN + ELIBMAX + ELIBEXEC + EILSEQ + ERESTART + ESTRPIPE + EUSERS + ENOTSOCK + EDESTADDRREQ + EMSGSIZE // 90 + EPROTOTYPE + ENOPROTOOPT + EPROTONOSUPPORT + ESOCKTNOSUPPORT + EOPNOTSUPP + EPFNOSUPPORT + EAFNOSUPPORT + EADDRINUSE + EADDRNOTAVAIL + ENETDOWN // 100 + ENETUNREACH + ENETRESET + ECONNABORTED + ECONNRESET + ENOBUFS + EISCONN + ENOTCONN + ESHUTDOWN + ETOOMANYREFS + ETIMEDOUT // 110 + ECONNREFUSED + EHOSTDOWN + EHOSTUNREACH + EALREADY + EINPROGRESS + ESTALE + EUCLEAN + ENOTNAM + ENAVAIL + EISNAM // 120 + EREMOTEIO + EDQUOT + ENOMEDIUM + EMEDIUMTYPE + ECANCELED + ENOKEY + EKEYEXPIRED + EKEYREVOKED + EKEYREJECTED + EOWNERDEAD // 130 + ENOTRECOVERABLE + ERFKILL + EHWPOISON ) -// Errno values from include/uapi/asm-generic/errno.h. -var ( - EDEADLK = &Errno{35, "resource deadlock would occur"} - ENAMETOOLONG = &Errno{36, "file name too long"} - ENOLCK = &Errno{37, "no record locks available"} - ENOSYS = &Errno{38, "invalid system call number"} - ENOTEMPTY = &Errno{39, "directory not empty"} - ELOOP = &Errno{40, "too many symbolic links encountered"} - EWOULDBLOCK = &Errno{EAGAIN.number, "operation would block"} - ENOMSG = &Errno{42, "no message of desired type"} - EIDRM = &Errno{43, "identifier removed"} - ECHRNG = &Errno{44, "channel number out of range"} - EL2NSYNC = &Errno{45, "level 2 not synchronized"} - EL3HLT = &Errno{46, "level 3 halted"} - EL3RST = &Errno{47, "level 3 reset"} - ELNRNG = &Errno{48, "link number out of range"} - EUNATCH = &Errno{49, "protocol driver not attached"} - ENOCSI = &Errno{50, "no CSI structure available"} - EL2HLT = &Errno{51, "level 2 halted"} - EBADE = &Errno{52, "invalid exchange"} - EBADR = &Errno{53, "invalid request descriptor"} - EXFULL = &Errno{54, "exchange full"} - ENOANO = &Errno{55, "no anode"} - EBADRQC = &Errno{56, "invalid request code"} - EBADSLT = &Errno{57, "invalid slot"} - EDEADLOCK = EDEADLK - EBFONT = &Errno{59, "bad font file format"} - ENOSTR = &Errno{60, "device not a stream"} - ENODATA = &Errno{61, "no data available"} - ETIME = &Errno{62, "timer expired"} - ENOSR = &Errno{63, "out of streams resources"} - ENONET = &Errno{64, "machine is not on the network"} - ENOPKG = &Errno{65, "package not installed"} - EREMOTE = &Errno{66, "object is remote"} - ENOLINK = &Errno{67, "link has been severed"} - EADV = &Errno{68, "advertise error"} - ESRMNT = &Errno{69, "srmount error"} - ECOMM = &Errno{70, "communication error on send"} - EPROTO = &Errno{71, "protocol error"} - EMULTIHOP = &Errno{72, "multihop attempted"} - EDOTDOT = &Errno{73, "RFS specific error"} - EBADMSG = &Errno{74, "not a data message"} - EOVERFLOW = &Errno{75, "value too large for defined data type"} - ENOTUNIQ = &Errno{76, "name not unique on network"} - EBADFD = &Errno{77, "file descriptor in bad state"} - EREMCHG = &Errno{78, "remote address changed"} - ELIBACC = &Errno{79, "can not access a needed shared library"} - ELIBBAD = &Errno{80, "accessing a corrupted shared library"} - ELIBSCN = &Errno{81, ".lib section in a.out corrupted"} - ELIBMAX = &Errno{82, "attempting to link in too many shared libraries"} - ELIBEXEC = &Errno{83, "cannot exec a shared library directly"} - EILSEQ = &Errno{84, "illegal byte sequence"} - ERESTART = &Errno{85, "interrupted system call should be restarted"} - ESTRPIPE = &Errno{86, "streams pipe error"} - EUSERS = &Errno{87, "too many users"} - ENOTSOCK = &Errno{88, "socket operation on non-socket"} - EDESTADDRREQ = &Errno{89, "destination address required"} - EMSGSIZE = &Errno{90, "message too long"} - EPROTOTYPE = &Errno{91, "protocol wrong type for socket"} - ENOPROTOOPT = &Errno{92, "protocol not available"} - EPROTONOSUPPORT = &Errno{93, "protocol not supported"} - ESOCKTNOSUPPORT = &Errno{94, "socket type not supported"} - EOPNOTSUPP = &Errno{95, "operation not supported on transport endpoint"} - EPFNOSUPPORT = &Errno{96, "protocol family not supported"} - EAFNOSUPPORT = &Errno{97, "address family not supported by protocol"} - EADDRINUSE = &Errno{98, "address already in use"} - EADDRNOTAVAIL = &Errno{99, "cannot assign requested address"} - ENETDOWN = &Errno{100, "network is down"} - ENETUNREACH = &Errno{101, "network is unreachable"} - ENETRESET = &Errno{102, "network dropped connection because of reset"} - ECONNABORTED = &Errno{103, "software caused connection abort"} - ECONNRESET = &Errno{104, "connection reset by peer"} - ENOBUFS = &Errno{105, "no buffer space available"} - EISCONN = &Errno{106, "transport endpoint is already connected"} - ENOTCONN = &Errno{107, "transport endpoint is not connected"} - ESHUTDOWN = &Errno{108, "cannot send after transport endpoint shutdown"} - ETOOMANYREFS = &Errno{109, "too many references: cannot splice"} - ETIMEDOUT = &Errno{110, "connection timed out"} - ECONNREFUSED = &Errno{111, "connection refused"} - EHOSTDOWN = &Errno{112, "host is down"} - EHOSTUNREACH = &Errno{113, "no route to host"} - EALREADY = &Errno{114, "operation already in progress"} - EINPROGRESS = &Errno{115, "operation now in progress"} - ESTALE = &Errno{116, "stale file handle"} - EUCLEAN = &Errno{117, "structure needs cleaning"} - ENOTNAM = &Errno{118, "not a XENIX named type file"} - ENAVAIL = &Errno{119, "no XENIX semaphores available"} - EISNAM = &Errno{120, "is a named type file"} - EREMOTEIO = &Errno{121, "remote I/O error"} - EDQUOT = &Errno{122, "quota exceeded"} - ENOMEDIUM = &Errno{123, "no medium found"} - EMEDIUMTYPE = &Errno{124, "wrong medium type"} - ECANCELED = &Errno{125, "operation Canceled"} - ENOKEY = &Errno{126, "required key not available"} - EKEYEXPIRED = &Errno{127, "key has expired"} - EKEYREVOKED = &Errno{128, "key has been revoked"} - EKEYREJECTED = &Errno{129, "key was rejected by service"} - EOWNERDEAD = &Errno{130, "owner died"} - ENOTRECOVERABLE = &Errno{131, "state not recoverable"} - ERFKILL = &Errno{132, "operation not possible due to RF-kill"} - EHWPOISON = &Errno{133, "memory page has hardware error"} +// errnos derived from other errnos +const ( + EWOULDBLOCK = EAGAIN + EDEADLOCK = EDEADLK ) diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index e11ca2d62..1e23850a9 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -19,7 +19,6 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/binary" ) // Constants for open(2). @@ -201,7 +200,7 @@ const ( ) // SizeOfStat is the size of a Stat struct. -var SizeOfStat = binary.Size(Stat{}) +var SizeOfStat = (*Stat)(nil).SizeBytes() // Flags for statx. const ( @@ -268,7 +267,7 @@ type Statx struct { } // SizeOfStatx is the size of a Statx struct. -var SizeOfStatx = binary.Size(Statx{}) +var SizeOfStatx = (*Statx)(nil).SizeBytes() // FileMode represents a mode_t. type FileMode uint16 diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 0faf015c7..51a39704b 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -14,8 +14,6 @@ package linux -import "gvisor.dev/gvisor/pkg/binary" - const ( // IFNAMSIZ is the size of the name field for IFReq. IFNAMSIZ = 16 @@ -66,7 +64,7 @@ func (ifr *IFReq) SetName(name string) { } // SizeOfIFReq is the binary size of an IFReq struct (40 bytes). -var SizeOfIFReq = binary.Size(IFReq{}) +var SizeOfIFReq = (*IFReq)(nil).SizeBytes() // IFMap contains interface hardware parameters. type IFMap struct { diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 35c632168..3fd05483a 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -245,6 +245,8 @@ const SizeOfXTCounters = 16 // include/uapi/linux/netfilter/x_tables.h. That struct contains a union // exposing different data to the user and kernel, but this struct holds only // the user data. +// +// +marshal type XTEntryMatch struct { MatchSize uint16 Name ExtensionName @@ -284,6 +286,8 @@ const SizeOfXTGetRevision = 30 // include/uapi/linux/netfilter/x_tables.h. That struct contains a union // exposing different data to the user and kernel, but this struct holds only // the user data. +// +// +marshal type XTEntryTarget struct { TargetSize uint16 Name ExtensionName @@ -306,6 +310,8 @@ type KernelXTEntryTarget struct { // XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, // RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTStandardTarget struct { Target XTEntryTarget // A positive verdict indicates a jump, and is the offset from the @@ -322,6 +328,8 @@ const SizeOfXTStandardTarget = 40 // beginning of user-defined chains by putting the name of the chain in // ErrorName. It corresponds to struct xt_error_target in // include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTErrorTarget struct { Target XTEntryTarget Name ErrorName @@ -349,6 +357,8 @@ const ( // NfNATIPV4Range corresponds to struct nf_nat_ipv4_range // in include/uapi/linux/netfilter/nf_nat.h. The fields are in // network byte order. +// +// +marshal type NfNATIPV4Range struct { Flags uint32 MinIP [4]byte @@ -359,6 +369,8 @@ type NfNATIPV4Range struct { // NfNATIPV4MultiRangeCompat corresponds to struct // nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h. +// +// +marshal type NfNATIPV4MultiRangeCompat struct { RangeSize uint32 RangeIPV4 NfNATIPV4Range @@ -366,6 +378,8 @@ type NfNATIPV4MultiRangeCompat struct { // XTRedirectTarget triggers a redirect when reached. // Adding 4 bytes of padding to make the struct 8 byte aligned. +// +// +marshal type XTRedirectTarget struct { Target XTEntryTarget NfRange NfNATIPV4MultiRangeCompat @@ -377,6 +391,8 @@ const SizeOfXTRedirectTarget = 56 // XTSNATTarget triggers Source NAT when reached. // Adding 4 bytes of padding to make the struct 8 byte aligned. +// +// +marshal type XTSNATTarget struct { Target XTEntryTarget NfRange NfNATIPV4MultiRangeCompat @@ -463,6 +479,8 @@ var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTReplace struct { Name TableName ValidHooks uint32 @@ -502,6 +520,8 @@ func (tn TableName) String() string { // ErrorName holds the name of a netfilter error. These can also hold // user-defined chains. +// +// +marshal type ErrorName [XT_FUNCTION_MAXNAMELEN]byte // String implements fmt.Stringer. @@ -520,6 +540,8 @@ func goString(cstring []byte) string { // XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp // in include/uapi/linux/netfilter/xt_tcpudp.h. +// +// +marshal type XTTCP struct { // SourcePortStart specifies the inclusive start of the range of source // ports to which the matcher applies. @@ -573,6 +595,8 @@ const ( // XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp // in include/uapi/linux/netfilter/xt_tcpudp.h. +// +// +marshal type XTUDP struct { // SourcePortStart is the inclusive start of the range of source ports // to which the matcher applies. @@ -613,6 +637,8 @@ const ( // IPTOwnerInfo holds data for matching packets with owner. It corresponds // to struct ipt_owner_info in libxt_owner.c of iptables binary. +// +// +marshal type IPTOwnerInfo struct { // UID is user id which created the packet. UID uint32 @@ -634,7 +660,7 @@ type IPTOwnerInfo struct { Match uint8 // Invert flips the meaning of Match field. - Invert uint8 + Invert uint8 `marshal:"unaligned"` } // SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo. diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index f7c70b430..b088b207c 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -264,6 +264,8 @@ const ( // NFNATRange corresponds to struct nf_nat_range in // include/uapi/linux/netfilter/nf_nat.h. +// +// +marshal type NFNATRange struct { Flags uint32 MinAddr Inet6Addr diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index bf73271c6..600820a0b 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -15,9 +15,8 @@ package linux import ( + "encoding/binary" "testing" - - "gvisor.dev/gvisor/pkg/binary" ) func TestSizes(t *testing.T) { @@ -42,7 +41,7 @@ func TestSizes(t *testing.T) { } for _, tc := range testCases { - if calculated := binary.Size(tc.typ); calculated != tc.defined { + if calculated := uintptr(binary.Size(tc.typ)); calculated != tc.defined { t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated) } } diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index b41f94a69..232fee67e 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -53,6 +53,8 @@ type SockAddrNetlink struct { const SockAddrNetlinkSize = 12 // NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkMessageHeader struct { Length uint32 Type uint16 @@ -99,6 +101,8 @@ const NLMSG_ALIGNTO = 4 // NetlinkAttrHeader is the header of a netlink attribute, followed by payload. // // This is struct nlattr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkAttrHeader struct { Length uint16 Type uint16 @@ -126,6 +130,8 @@ const ( ) // NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkErrorMessage struct { Error int32 Header NetlinkMessageHeader diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index ceda0a8d3..581a11b24 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -85,6 +85,8 @@ const ( ) // InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h. +// +// +marshal type InterfaceInfoMessage struct { Family uint8 _ uint8 @@ -164,6 +166,8 @@ const ( ) // InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h. +// +// +marshal type InterfaceAddrMessage struct { Family uint8 PrefixLen uint8 @@ -193,6 +197,8 @@ const ( ) // RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h. +// +// +marshal type RouteMessage struct { Family uint8 DstLen uint8 diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 185eee0bb..95871b8a5 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -15,7 +15,6 @@ package linux import ( - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal" ) @@ -251,18 +250,24 @@ type SockAddrInet struct { } // Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h. +// +// +marshal type Inet6MulticastRequest struct { MulticastAddr Inet6Addr InterfaceIndex int32 } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. +// +// +marshal type InetMulticastRequest struct { MulticastAddr InetAddr InterfaceAddr InetAddr } // InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h. +// +// +marshal type InetMulticastRequestWithNIC struct { InetMulticastRequest InterfaceIndex int32 @@ -491,7 +496,7 @@ type TCPInfo struct { } // SizeOfTCPInfo is the binary size of a TCPInfo struct. -var SizeOfTCPInfo = int(binary.Size(TCPInfo{})) +var SizeOfTCPInfo = (*TCPInfo)(nil).SizeBytes() // Control message types, from linux/socket.h. const ( @@ -502,6 +507,8 @@ const ( // A ControlMessageHeader is the header for a socket control message. // // ControlMessageHeader represents struct cmsghdr from linux/socket.h. +// +// +marshal type ControlMessageHeader struct { Length uint64 Level int32 @@ -510,7 +517,7 @@ type ControlMessageHeader struct { // SizeOfControlMessageHeader is the binary size of a ControlMessageHeader // struct. -var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) +var SizeOfControlMessageHeader = (*ControlMessageHeader)(nil).SizeBytes() // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // @@ -527,6 +534,7 @@ type ControlMessageCredentials struct { // // ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. // +// +marshal // +stateify savable type ControlMessageIPPacketInfo struct { NIC int32 @@ -536,7 +544,7 @@ type ControlMessageIPPacketInfo struct { // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. -var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{})) +var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() // A ControlMessageRights is an SCM_RIGHTS socket control message. type ControlMessageRights []int32 diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 1a30f6967..11072d4de 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -5,6 +5,8 @@ package(licenses = ["notice"]) go_library( name = "atomicbitops", srcs = [ + "aligned_32bit_unsafe.go", + "aligned_64bit.go", "atomicbitops.go", "atomicbitops_amd64.s", "atomicbitops_arm64.s", diff --git a/pkg/atomicbitops/aligned_32bit_unsafe.go b/pkg/atomicbitops/aligned_32bit_unsafe.go new file mode 100644 index 000000000..df706b453 --- /dev/null +++ b/pkg/atomicbitops/aligned_32bit_unsafe.go @@ -0,0 +1,96 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm mips 386 + +package atomicbitops + +import ( + "sync/atomic" + "unsafe" +) + +// AlignedAtomicInt64 is an atomic int64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. +// +// Per https://golang.org/pkg/sync/atomic/#pkg-note-BUG: +// +// "On ARM, 386, and 32-bit MIPS, it is the caller's responsibility to arrange +// for 64-bit alignment of 64-bit words accessed atomically. The first word in +// a variable or in an allocated struct, array, or slice can be relied upon to +// be 64-bit aligned." +// +// +stateify savable +type AlignedAtomicInt64 struct { + value [15]byte +} + +func (aa *AlignedAtomicInt64) ptr() *int64 { + // In the 15-byte aa.value, there are guaranteed to be 8 contiguous + // bytes with 64-bit alignment. We find an address in this range by + // adding 7, then clear the 3 least significant bits to get its start. + return (*int64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value[0])) + 7) &^ 7)) +} + +// Load is analagous to atomic.LoadInt64. +func (aa *AlignedAtomicInt64) Load() int64 { + return atomic.LoadInt64(aa.ptr()) +} + +// Store is analagous to atomic.StoreInt64. +func (aa *AlignedAtomicInt64) Store(v int64) { + atomic.StoreInt64(aa.ptr(), v) +} + +// Add is analagous to atomic.AddInt64. +func (aa *AlignedAtomicInt64) Add(v int64) int64 { + return atomic.AddInt64(aa.ptr(), v) +} + +// AlignedAtomicUint64 is an atomic uint64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. +// +// Per https://golang.org/pkg/sync/atomic/#pkg-note-BUG: +// +// "On ARM, 386, and 32-bit MIPS, it is the caller's responsibility to arrange +// for 64-bit alignment of 64-bit words accessed atomically. The first word in +// a variable or in an allocated struct, array, or slice can be relied upon to +// be 64-bit aligned." +// +// +stateify savable +type AlignedAtomicUint64 struct { + value [15]byte +} + +func (aa *AlignedAtomicUint64) ptr() *uint64 { + // In the 15-byte aa.value, there are guaranteed to be 8 contiguous + // bytes with 64-bit alignment. We find an address in this range by + // adding 7, then clear the 3 least significant bits to get its start. + return (*uint64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value[0])) + 7) &^ 7)) +} + +// Load is analagous to atomic.LoadUint64. +func (aa *AlignedAtomicUint64) Load() uint64 { + return atomic.LoadUint64(aa.ptr()) +} + +// Store is analagous to atomic.StoreUint64. +func (aa *AlignedAtomicUint64) Store(v uint64) { + atomic.StoreUint64(aa.ptr(), v) +} + +// Add is analagous to atomic.AddUint64. +func (aa *AlignedAtomicUint64) Add(v uint64) uint64 { + return atomic.AddUint64(aa.ptr(), v) +} diff --git a/pkg/atomicbitops/aligned_64bit.go b/pkg/atomicbitops/aligned_64bit.go new file mode 100644 index 000000000..1544c7814 --- /dev/null +++ b/pkg/atomicbitops/aligned_64bit.go @@ -0,0 +1,71 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !arm,!mips,!386 + +package atomicbitops + +import "sync/atomic" + +// AlignedAtomicInt64 is an atomic int64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. On most architectures, it's just a regular +// int64. +// +// See aligned_unsafe.go in this directory for justification. +// +// +stateify savable +type AlignedAtomicInt64 struct { + value int64 +} + +// Load is analagous to atomic.LoadInt64. +func (aa *AlignedAtomicInt64) Load() int64 { + return atomic.LoadInt64(&aa.value) +} + +// Store is analagous to atomic.StoreInt64. +func (aa *AlignedAtomicInt64) Store(v int64) { + atomic.StoreInt64(&aa.value, v) +} + +// Add is analagous to atomic.AddInt64. +func (aa *AlignedAtomicInt64) Add(v int64) int64 { + return atomic.AddInt64(&aa.value, v) +} + +// AlignedAtomicUint64 is an atomic uint64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. On most architectures, it's just a regular +// uint64. +// +// See aligned_unsafe.go in this directory for justification. +// +// +stateify savable +type AlignedAtomicUint64 struct { + value uint64 +} + +// Load is analagous to atomic.LoadUint64. +func (aa *AlignedAtomicUint64) Load() uint64 { + return atomic.LoadUint64(&aa.value) +} + +// Store is analagous to atomic.StoreUint64. +func (aa *AlignedAtomicUint64) Store(v uint64) { + atomic.StoreUint64(&aa.value, v) +} + +// Add is analagous to atomic.AddUint64. +func (aa *AlignedAtomicUint64) Add(v uint64) uint64 { + return atomic.AddUint64(&aa.value, v) +} diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go index a26433ad6..d16448c3d 100644 --- a/pkg/bits/bits.go +++ b/pkg/bits/bits.go @@ -14,3 +14,13 @@ // Package bits includes all bit related types and operations. package bits + +// AlignUp rounds a length up to an alignment. align must be a power of 2. +func AlignUp(length int, align uint) int { + return (length + int(align) - 1) & ^(int(align) - 1) +} + +// AlignDown rounds a length down to an alignment. align must be a power of 2. +func AlignDown(length int, align uint) int { + return length & ^(int(align) - 1) +} diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index 2a6977f85..c17390522 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -26,6 +26,7 @@ go_test( library = ":bpf", deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/hostarch", + "//pkg/marshal", ], ) diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go index c85d786b9..f64a2dc50 100644 --- a/pkg/bpf/interpreter_test.go +++ b/pkg/bpf/interpreter_test.go @@ -15,10 +15,12 @@ package bpf import ( + "encoding/binary" "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" ) func TestCompilationErrors(t *testing.T) { @@ -750,29 +752,29 @@ func TestSimpleFilter(t *testing.T) { // desc is the test's description. desc string - // seccompData is the input data. - seccompData + // SeccompData is the input data. + data linux.SeccompData // expectedRet is the expected return value of the BPF program. expectedRet uint32 }{ { desc: "Invalid arch is rejected", - seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */}, + data: linux.SeccompData{Nr: 1 /* x86 exit */, Arch: 0x40000003 /* AUDIT_ARCH_I386 */}, expectedRet: 0, }, { desc: "Disallowed syscall is rejected", - seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e}, + data: linux.SeccompData{Nr: 105 /* __NR_setuid */, Arch: 0xc000003e}, expectedRet: 0, }, { desc: "Allowed syscall is indeed allowed", - seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e}, + data: linux.SeccompData{Nr: 231 /* __NR_exit_group */, Arch: 0xc000003e}, expectedRet: 0x7fff0000, }, } { - ret, err := Exec(p, test.seccompData.asInput()) + ret, err := Exec(p, dataAsInput(&test.data)) if err != nil { t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) continue @@ -792,6 +794,6 @@ type seccompData struct { } // asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() Input { - return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +func dataAsInput(data *linux.SeccompData) Input { + return InputBytes{marshal.Marshal(data), hostarch.ByteOrder} } diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index 1186f788e..19cd28a32 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -21,7 +21,6 @@ go_library( "buffer.go", "buffer_list.go", "pool.go", - "safemem.go", "view.go", "view_unsafe.go", ], @@ -29,8 +28,6 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/safemem", - "//pkg/usermem", ], ) @@ -38,13 +35,12 @@ go_test( name = "buffer_test", size = "small", srcs = [ + "buffer_test.go", "pool_test.go", - "safemem_test.go", "view_test.go", ], library = ":buffer", deps = [ - "//pkg/safemem", "//pkg/state", ], ) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 311808ae9..5b77a6a3f 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -33,12 +33,40 @@ func (b *buffer) init(size int) { b.data = make([]byte, size) } +// initWithData initializes b with data, taking ownership. +func (b *buffer) initWithData(data []byte) { + b.data = data + b.read = 0 + b.write = len(data) +} + // Reset resets read and write locations, effectively emptying the buffer. func (b *buffer) Reset() { b.read = 0 b.write = 0 } +// Remove removes r from the unread portion. It returns false if r does not +// fully reside in b. +func (b *buffer) Remove(r Range) bool { + sz := b.ReadSize() + switch { + case r.Len() != r.Intersect(Range{end: sz}).Len(): + return false + case r.Len() == 0: + // Noop + case r.begin == 0: + b.read += r.end + case r.end == sz: + b.write -= r.Len() + default: + // Remove from the middle of b.data. + copy(b.data[b.read+r.begin:], b.data[b.read+r.end:b.write]) + b.write -= r.Len() + } + return true +} + // Full indicates the buffer is full. // // This indicates there is no capacity left to write. diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go new file mode 100644 index 000000000..32db841e4 --- /dev/null +++ b/pkg/buffer/buffer_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "bytes" + "testing" +) + +func TestBufferRemove(t *testing.T) { + sample := []byte("01234567") + + // Success cases + for _, tc := range []struct { + desc string + data []byte + rng Range + want []byte + }{ + { + desc: "empty slice", + }, + { + desc: "empty range", + data: sample, + want: sample, + }, + { + desc: "empty range with positive begin", + data: sample, + rng: Range{begin: 1, end: 1}, + want: sample, + }, + { + desc: "range at beginning", + data: sample, + rng: Range{begin: 0, end: 1}, + want: sample[1:], + }, + { + desc: "range in middle", + data: sample, + rng: Range{begin: 2, end: 4}, + want: []byte("014567"), + }, + { + desc: "range at end", + data: sample, + rng: Range{begin: 7, end: 8}, + want: sample[:7], + }, + { + desc: "range all", + data: sample, + rng: Range{begin: 0, end: 8}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var buf buffer + buf.initWithData(tc.data) + if ok := buf.Remove(tc.rng); !ok { + t.Errorf("buf.Remove(%#v) = false, want true", tc.rng) + } else if got := buf.ReadSlice(); !bytes.Equal(got, tc.want) { + t.Errorf("buf.ReadSlice() = %q, want %q", got, tc.want) + } + }) + } + + // Failure cases + for _, tc := range []struct { + desc string + data []byte + rng Range + }{ + { + desc: "begin out-of-range", + data: sample, + rng: Range{begin: -1, end: 4}, + }, + { + desc: "end out-of-range", + data: sample, + rng: Range{begin: 4, end: 9}, + }, + { + desc: "both out-of-range", + data: sample, + rng: Range{begin: -100, end: 100}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var buf buffer + buf.initWithData(tc.data) + if ok := buf.Remove(tc.rng); ok { + t.Errorf("buf.Remove(%#v) = true, want false", tc.rng) + } + }) + } +} diff --git a/pkg/buffer/pool.go b/pkg/buffer/pool.go index 7ad6132ab..2ec41dd4f 100644 --- a/pkg/buffer/pool.go +++ b/pkg/buffer/pool.go @@ -42,6 +42,13 @@ type pool struct { // get gets a new buffer from p. func (p *pool) get() *buffer { + buf := p.getNoInit() + buf.init(p.bufferSize) + return buf +} + +// get gets a new buffer from p without initializing it. +func (p *pool) getNoInit() *buffer { if p.avail == nil { p.avail = p.embeddedStorage[:] } @@ -52,7 +59,6 @@ func (p *pool) get() *buffer { p.bufferSize = defaultBufferSize } buf := &p.avail[0] - buf.init(p.bufferSize) p.avail = p.avail[1:] return buf } @@ -62,6 +68,7 @@ func (p *pool) put(buf *buffer) { // Remove reference to the underlying storage, allowing it to be garbage // collected. buf.data = nil + buf.Reset() } // setBufferSize sets the size of underlying storage buffer for future diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go deleted file mode 100644 index 8b42575b4..000000000 --- a/pkg/buffer/safemem.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "gvisor.dev/gvisor/pkg/safemem" -) - -// WriteBlock returns this buffer as a write Block. -func (b *buffer) WriteBlock() safemem.Block { - return safemem.BlockFromSafeSlice(b.WriteSlice()) -} - -// ReadBlock returns this buffer as a read Block. -func (b *buffer) ReadBlock() safemem.Block { - return safemem.BlockFromSafeSlice(b.ReadSlice()) -} - -// WriteFromSafememReader writes up to count bytes from r to v and advances the -// write index by the number of bytes written. It calls r.ReadToBlocks() at -// most once. -func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) { - if count == 0 { - return 0, nil - } - - var ( - dst safemem.BlockSeq - blocks []safemem.Block - ) - - // Need at least one buffer. - firstBuf := v.data.Back() - if firstBuf == nil { - firstBuf = v.pool.get() - v.data.PushBack(firstBuf) - } - - // Does the last block have sufficient capacity alone? - if l := uint64(firstBuf.WriteSize()); l >= count { - dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count)) - } else { - // Append blocks until sufficient. - count -= l - blocks = append(blocks, firstBuf.WriteBlock()) - for count > 0 { - emptyBuf := v.pool.get() - v.data.PushBack(emptyBuf) - block := emptyBuf.WriteBlock().TakeFirst64(count) - count -= uint64(block.Len()) - blocks = append(blocks, block) - } - dst = safemem.BlockSeqFromSlice(blocks) - } - - // Perform I/O. - n, err := r.ReadToBlocks(dst) - v.size += int64(n) - - // Update all indices. - for left := n; left > 0; firstBuf = firstBuf.Next() { - if l := firstBuf.WriteSize(); left >= uint64(l) { - firstBuf.WriteMove(l) // Whole block. - left -= uint64(l) - } else { - firstBuf.WriteMove(int(left)) // Partial block. - left = 0 - } - } - - return n, err -} - -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the -// write index by the number of bytes written. -func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes()) -} - -// ReadToSafememWriter reads up to count bytes from v to w. It does not advance -// the read index. It calls w.WriteFromBlocks() at most once. -func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) { - if count == 0 { - return 0, nil - } - - var ( - src safemem.BlockSeq - blocks []safemem.Block - ) - - firstBuf := v.data.Front() - if firstBuf == nil { - return 0, nil // No EOF. - } - - // Is all the data in a single block? - if l := uint64(firstBuf.ReadSize()); l >= count { - src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count)) - } else { - // Build a list of all the buffers. - count -= l - blocks = append(blocks, firstBuf.ReadBlock()) - for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() { - block := buf.ReadBlock().TakeFirst64(count) - count -= uint64(block.Len()) - blocks = append(blocks, block) - } - src = safemem.BlockSeqFromSlice(blocks) - } - - // Perform I/O. As documented, we don't advance the read index. - return w.WriteFromBlocks(src) -} - -// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the -// read index by the number of bytes read, such that it's only safe to call if -// the caller guarantees that ReadToBlocks will only be called once. -func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes()) -} diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go deleted file mode 100644 index 721cc5934..000000000 --- a/pkg/buffer/safemem_test.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/safemem" -) - -func TestSafemem(t *testing.T) { - const bufferSize = defaultBufferSize - - testCases := []struct { - name string - input string - output string - readLen int - op func(*View) - }{ - // Basic coverage. - { - name: "short", - input: "010", - output: "010", - }, - { - name: "long", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize) + "0", - }, - { - name: "short-read", - input: "0", - readLen: 100, // > size. - output: "0", - }, - { - name: "zero-read", - input: "0", - output: "", - }, - { - name: "read-empty", - input: "", - readLen: 1, // > size. - output: "", - }, - - // Ensure offsets work. - { - name: "offsets-short", - input: "012", - output: "2", - op: func(v *View) { - v.TrimFront(2) - }, - }, - { - name: "offsets-long0", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: strings.Repeat("1", bufferSize) + "0", - op: func(v *View) { - v.TrimFront(1) - }, - }, - { - name: "offsets-long1", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: strings.Repeat("1", bufferSize-1) + "0", - op: func(v *View) { - v.TrimFront(2) - }, - }, - { - name: "offsets-long2", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "10", - op: func(v *View) { - v.TrimFront(bufferSize) - }, - }, - - // Ensure truncation works. - { - name: "truncate-short", - input: "012", - output: "01", - op: func(v *View) { - v.Truncate(2) - }, - }, - { - name: "truncate-long0", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize), - op: func(v *View) { - v.Truncate(bufferSize + 1) - }, - }, - { - name: "truncate-long1", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize-1), - op: func(v *View) { - v.Truncate(bufferSize) - }, - }, - { - name: "truncate-long2", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "01", - op: func(v *View) { - v.Truncate(2) - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Construct the new view. - var view View - bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input))) - n, err := view.WriteFromBlocks(bs) - if err != nil { - t.Errorf("expected err nil, got %v", err) - } - if n != uint64(len(tc.input)) { - t.Errorf("expected %d bytes, got %d", len(tc.input), n) - } - - // Run the operation. - if tc.op != nil { - tc.op(&view) - } - - // Read and validate. - readLen := tc.readLen - if readLen == 0 { - readLen = len(tc.output) // Default. - } - out := make([]byte, readLen) - bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out)) - n, err = view.ReadToBlocks(bs) - if err != nil { - t.Errorf("expected nil, got %v", err) - } - if n != uint64(len(tc.output)) { - t.Errorf("expected %d bytes, got %d", len(tc.output), n) - } - - // Ensure the contents are correct. - if !bytes.Equal(out[:n], []byte(tc.output[:n])) { - t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out)) - } - }) - } -} diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index 00652d675..7bcfcd543 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -19,6 +19,9 @@ import ( "io" ) +// Buffer is an alias to View. +type Buffer = View + // View is a non-linear buffer. // // All methods are thread compatible. @@ -39,6 +42,51 @@ func (v *View) TrimFront(count int64) { } } +// Remove deletes data at specified location in v. It returns false if specified +// range does not fully reside in v. +func (v *View) Remove(offset, length int) bool { + if offset < 0 || length < 0 { + return false + } + tgt := Range{begin: offset, end: offset + length} + if tgt.Len() != tgt.Intersect(Range{end: int(v.size)}).Len() { + return false + } + + // Scan through each buffer and remove intersections. + var curr Range + for buf := v.data.Front(); buf != nil; { + origLen := buf.ReadSize() + curr.end = curr.begin + origLen + + if x := curr.Intersect(tgt); x.Len() > 0 { + if !buf.Remove(x.Offset(-curr.begin)) { + panic("buf.Remove() failed") + } + if buf.ReadSize() == 0 { + // buf fully removed, removing it from the list. + oldBuf := buf + buf = buf.Next() + v.data.Remove(oldBuf) + v.pool.put(oldBuf) + } else { + // Only partial data intersects, moving on to next one. + buf = buf.Next() + } + v.size -= int64(x.Len()) + } else { + // This buffer is not in range, moving on to next one. + buf = buf.Next() + } + + curr.begin += origLen + if curr.begin >= tgt.end { + break + } + } + return true +} + // ReadAt implements io.ReaderAt.ReadAt. func (v *View) ReadAt(p []byte, offset int64) (int, error) { var ( @@ -81,7 +129,6 @@ func (v *View) advanceRead(count int64) { oldBuf := buf buf = buf.Next() // Iterate. v.data.Remove(oldBuf) - oldBuf.Reset() v.pool.put(oldBuf) // Update counts. @@ -118,7 +165,6 @@ func (v *View) Truncate(length int64) { // Drop the buffer completely; see above. v.data.Remove(buf) - buf.Reset() v.pool.put(buf) v.size -= sz } @@ -224,6 +270,78 @@ func (v *View) Append(data []byte) { } } +// AppendOwned takes ownership of data and appends it to v. +func (v *View) AppendOwned(data []byte) { + if len(data) > 0 { + buf := v.pool.getNoInit() + buf.initWithData(data) + v.data.PushBack(buf) + v.size += int64(len(data)) + } +} + +// PullUp makes the specified range contiguous and returns the backing memory. +func (v *View) PullUp(offset, length int) ([]byte, bool) { + if length == 0 { + return nil, true + } + tgt := Range{begin: offset, end: offset + length} + if tgt.Intersect(Range{end: int(v.size)}).Len() != length { + return nil, false + } + + curr := Range{} + buf := v.data.Front() + for ; buf != nil; buf = buf.Next() { + origLen := buf.ReadSize() + curr.end = curr.begin + origLen + + if x := curr.Intersect(tgt); x.Len() == tgt.Len() { + // buf covers the whole requested target range. + sub := x.Offset(-curr.begin) + return buf.ReadSlice()[sub.begin:sub.end], true + } else if x.Len() > 0 { + // buf is pointing at the starting buffer we want to merge. + break + } + + curr.begin += origLen + } + + // Calculate the total merged length. + totLen := 0 + for n := buf; n != nil; n = n.Next() { + totLen += n.ReadSize() + if curr.begin+totLen >= tgt.end { + break + } + } + + // Merge the buffers. + data := make([]byte, totLen) + off := 0 + for n := buf; n != nil && off < totLen; { + copy(data[off:], n.ReadSlice()) + off += n.ReadSize() + + // Remove buffers except for the first one, which will be reused. + if n == buf { + n = n.Next() + } else { + old := n + n = n.Next() + v.data.Remove(old) + v.pool.put(old) + } + } + + // Update the first buffer with merged data. + buf.initWithData(data) + + r := tgt.Offset(-curr.begin) + return buf.data[r.begin:r.end], true +} + // Flatten returns a flattened copy of this data. // // This method should not be used in any performance-sensitive paths. It may @@ -267,6 +385,27 @@ func (v *View) Apply(fn func([]byte)) { } } +// SubApply applies fn to a given range of data in v. Any part of the range +// outside of v is ignored. +func (v *View) SubApply(offset, length int, fn func([]byte)) { + for buf := v.data.Front(); length > 0 && buf != nil; buf = buf.Next() { + d := buf.ReadSlice() + if offset >= len(d) { + offset -= len(d) + continue + } + if offset > 0 { + d = d[offset:] + offset = 0 + } + if length < len(d) { + d = d[:length] + } + fn(d) + length -= len(d) + } +} + // Merge merges the provided View with this one. // // The other view will be appended to v, and other will be empty after this @@ -389,3 +528,39 @@ func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { } return done, err } + +// A Range specifies a range of buffer. +type Range struct { + begin int + end int +} + +// Intersect returns the intersection of x and y. +func (x Range) Intersect(y Range) Range { + if x.begin < y.begin { + x.begin = y.begin + } + if x.end > y.end { + x.end = y.end + } + if x.begin >= x.end { + return Range{} + } + return x +} + +// Offset returns x offset by off. +func (x Range) Offset(off int) Range { + x.begin += off + x.end += off + return x +} + +// Len returns the length of x. +func (x Range) Len() int { + l := x.end - x.begin + if l < 0 { + l = 0 + } + return l +} diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 839af0223..796efa240 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -17,7 +17,9 @@ package buffer import ( "bytes" "context" + "fmt" "io" + "reflect" "strings" "testing" @@ -237,6 +239,18 @@ func TestView(t *testing.T) { }, }, + // AppendOwned. + { + name: "append-owned", + input: "hello", + output: "hello world", + op: func(t *testing.T, v *View) { + b := []byte("Xworld") + v.AppendOwned(b) + b[0] = ' ' + }, + }, + // Truncate. { name: "truncate", @@ -495,6 +509,267 @@ func TestView(t *testing.T) { } } +func TestViewPullUp(t *testing.T) { + for _, tc := range []struct { + desc string + inputs []string + offset int + length int + output string + failed bool + // lengths is the lengths of each buffer node after the pull up. + lengths []int + }{ + { + desc: "whole empty view", + }, + { + desc: "zero pull", + inputs: []string{"hello", " world"}, + lengths: []int{5, 6}, + }, + { + desc: "whole view", + inputs: []string{"hello", " world"}, + offset: 0, + length: 11, + output: "hello world", + lengths: []int{11}, + }, + { + desc: "middle to end aligned", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 4, + length: 10, + output: "456789abcd", + lengths: []int{4, 10}, + }, + { + desc: "middle to end unaligned", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 6, + length: 8, + output: "6789abcd", + lengths: []int{4, 10}, + }, + { + desc: "middle aligned", + inputs: []string{"0123", "45678", "9abcd", "efgh"}, + offset: 6, + length: 5, + output: "6789a", + lengths: []int{4, 10, 4}, + }, + + // Failed cases. + { + desc: "empty view - length too long", + offset: 0, + length: 1, + failed: true, + }, + { + desc: "empty view - offset too large", + offset: 1, + length: 1, + failed: true, + }, + { + desc: "length too long", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 4, + length: 100, + failed: true, + lengths: []int{4, 5, 5}, + }, + { + desc: "offset too large", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 100, + length: 1, + failed: true, + lengths: []int{4, 5, 5}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.inputs { + v.AppendOwned([]byte(s)) + } + + got, gotOk := v.PullUp(tc.offset, tc.length) + want, wantOk := []byte(tc.output), !tc.failed + if gotOk != wantOk || !bytes.Equal(got, want) { + t.Errorf("v.PullUp(%d, %d) = %q, %t; %q, %t", tc.offset, tc.length, got, gotOk, want, wantOk) + } + + var gotLengths []int + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + gotLengths = append(gotLengths, buf.ReadSize()) + } + if !reflect.DeepEqual(gotLengths, tc.lengths) { + t.Errorf("lengths = %v; want %v", gotLengths, tc.lengths) + } + }) + } +} + +func TestViewRemove(t *testing.T) { + // Success cases + for _, tc := range []struct { + desc string + // before is the contents for each buffer node initially. + before []string + // after is the contents for each buffer node after removal. + after []string + offset int + length int + }{ + { + desc: "empty view", + }, + { + desc: "nothing removed", + before: []string{"hello", " world"}, + after: []string{"hello", " world"}, + }, + { + desc: "whole view", + before: []string{"hello", " world"}, + offset: 0, + length: 11, + }, + { + desc: "beginning to middle aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"9abcd"}, + offset: 0, + length: 9, + }, + { + desc: "beginning to middle unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"678", "9abcd"}, + offset: 0, + length: 6, + }, + { + desc: "middle to end aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123"}, + offset: 4, + length: 10, + }, + { + desc: "middle to end unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "45"}, + offset: 6, + length: 8, + }, + { + desc: "middle aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "9abcd"}, + offset: 4, + length: 5, + }, + { + desc: "middle unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "4578", "9abcd"}, + offset: 6, + length: 1, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.before { + v.AppendOwned([]byte(s)) + } + + if ok := v.Remove(tc.offset, tc.length); !ok { + t.Errorf("v.Remove(%d, %d) = false, want true", tc.offset, tc.length) + } + + var got []string + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + got = append(got, string(buf.ReadSlice())) + } + if !reflect.DeepEqual(got, tc.after) { + t.Errorf("after = %v; want %v", got, tc.after) + } + }) + } + + // Failure cases + for _, tc := range []struct { + desc string + // before is the contents for each buffer node initially. + before []string + offset int + length int + }{ + { + desc: "offset out-of-range", + before: []string{"hello", " world"}, + offset: -1, + length: 3, + }, + { + desc: "length too long", + before: []string{"hello", " world"}, + offset: 0, + length: 12, + }, + { + desc: "length too long with positive offset", + before: []string{"hello", " world"}, + offset: 3, + length: 9, + }, + { + desc: "length negative", + before: []string{"hello", " world"}, + offset: 0, + length: -1, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.before { + v.AppendOwned([]byte(s)) + } + if ok := v.Remove(tc.offset, tc.length); ok { + t.Errorf("v.Remove(%d, %d) = true, want false", tc.offset, tc.length) + } + }) + } +} + +func TestViewSubApply(t *testing.T) { + var v View + v.AppendOwned([]byte("0123")) + v.AppendOwned([]byte("45678")) + v.AppendOwned([]byte("9abcd")) + + data := []byte("0123456789abcd") + + for i := 0; i <= len(data); i++ { + for j := i; j <= len(data); j++ { + t.Run(fmt.Sprintf("SubApply(%d,%d)", i, j), func(t *testing.T) { + var got []byte + v.SubApply(i, j-i, func(b []byte) { + got = append(got, b...) + }) + if want := data[i:j]; !bytes.Equal(got, want) { + t.Errorf("got = %q; want %q", got, want) + } + }) + } + } +} + func doSaveAndLoad(t *testing.T, toSave, toLoad *View) { t.Helper() var buf bytes.Buffer @@ -542,3 +817,84 @@ func TestSaveRestoreView(t *testing.T) { t.Errorf("v.Flatten() = %x, want %x", got, data) } } + +func TestRangeIntersect(t *testing.T) { + for _, tc := range []struct { + desc string + x, y, want Range + }{ + { + desc: "empty intersects empty", + }, + { + desc: "empty intersection", + x: Range{end: 10}, + y: Range{begin: 10, end: 20}, + }, + { + desc: "some intersection", + x: Range{begin: 5, end: 20}, + y: Range{end: 10}, + want: Range{begin: 5, end: 10}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + if got := tc.x.Intersect(tc.y); got != tc.want { + t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.x, tc.y, got, tc.want) + } + if got := tc.y.Intersect(tc.x); got != tc.want { + t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.y, tc.x, got, tc.want) + } + }) + } +} + +func TestRangeOffset(t *testing.T) { + for _, tc := range []struct { + input Range + offset int + output Range + }{ + { + input: Range{}, + offset: 0, + output: Range{}, + }, + { + input: Range{}, + offset: -1, + output: Range{begin: -1, end: -1}, + }, + { + input: Range{begin: 10, end: 20}, + offset: -1, + output: Range{begin: 9, end: 19}, + }, + { + input: Range{begin: 10, end: 20}, + offset: 2, + output: Range{begin: 12, end: 22}, + }, + } { + if got := tc.input.Offset(tc.offset); got != tc.output { + t.Errorf("(%#v).Offset(%d) = %#v, want %#v", tc.input, tc.offset, got, tc.output) + } + } +} + +func TestRangeLen(t *testing.T) { + for _, tc := range []struct { + r Range + want int + }{ + {r: Range{}, want: 0}, + {r: Range{begin: 1, end: 1}, want: 0}, + {r: Range{begin: -1, end: -1}, want: 0}, + {r: Range{end: 10}, want: 10}, + {r: Range{begin: 5, end: 10}, want: 5}, + } { + if got := tc.r.Len(); got != tc.want { + t.Errorf("(%#v).Len() = %d, want %d", tc.r, got, tc.want) + } + } +} diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index 1f75319a7..70018cf18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -6,10 +6,7 @@ go_library( name = "compressio", srcs = ["compressio.go"], visibility = ["//:sandbox"], - deps = [ - "//pkg/binary", - "//pkg/sync", - ], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index b094c5662..615d7f134 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -48,12 +48,12 @@ import ( "compress/flate" "crypto/hmac" "crypto/sha256" + "encoding/binary" "errors" "hash" "io" "runtime" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sync" ) @@ -130,6 +130,10 @@ type worker struct { hashPool *hashPool input chan *chunk output chan result + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. + scratch [4]byte } // work is the main work routine; see worker. @@ -167,7 +171,8 @@ func (w *worker) work(compress bool, level int) { // Write the hash, if enabled. if h != nil { - binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) + h.Write(w.scratch[:4]) c.h = h h = nil } @@ -175,7 +180,8 @@ func (w *worker) work(compress bool, level int) { // Check the hash of the compressed contents. if h != nil { h.Write(c.compressed.Bytes()) - binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) + h.Write(w.scratch[:4]) io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum))) sum := h.Sum(nil) @@ -352,6 +358,10 @@ type Reader struct { // in is the source. in io.Reader + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. + scratch [4]byte } var _ io.Reader = (*Reader)(nil) @@ -368,14 +378,15 @@ func NewReader(in io.Reader, key []byte) (*Reader, error) { // Use double buffering for read. r.init(key, 2*runtime.GOMAXPROCS(0), false, 0) - var err error - if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil { + if _, err := io.ReadFull(in, r.scratch[:4]); err != nil { return nil, err } + r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4]) if r.hashPool != nil { h := r.hashPool.getHash() - binary.WriteUint32(h, binary.BigEndian, r.chunkSize) + binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize) + h.Write(r.scratch[:4]) r.lastSum = h.Sum(nil) r.hashPool.putHash(h) sum := make([]byte, len(r.lastSum)) @@ -467,8 +478,7 @@ func (r *Reader) Read(p []byte) (int, error) { // reader. The length is used to limit the reader. // // See writer.flush. - l, err := binary.ReadUint32(r.in, binary.BigEndian) - if err != nil { + if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil { // This is generally okay as long as there // are still buffers outstanding. We actually // just wait for completion of those buffers here @@ -488,6 +498,7 @@ func (r *Reader) Read(p []byte) (int, error) { return done, err } } + l := binary.BigEndian.Uint32(r.scratch[:4]) // Read this chunk and schedule decompression. compressed := bufPool.Get().(*bytes.Buffer) @@ -573,6 +584,10 @@ type Writer struct { // closed indicates whether the file has been closed. closed bool + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. + scratch [4]byte } var _ io.Writer = (*Writer)(nil) @@ -594,13 +609,15 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, } w.init(key, 1+runtime.GOMAXPROCS(0), true, level) - if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil { + binary.BigEndian.PutUint32(w.scratch[:], chunkSize) + if _, err := w.out.Write(w.scratch[:4]); err != nil { return nil, err } if w.hashPool != nil { h := w.hashPool.getHash() - binary.WriteUint32(h, binary.BigEndian, chunkSize) + binary.BigEndian.PutUint32(w.scratch[:], chunkSize) + h.Write(w.scratch[:4]) w.lastSum = h.Sum(nil) w.hashPool.putHash(h) if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil { @@ -616,7 +633,9 @@ func (w *Writer) flush(c *chunk) error { // Prefix each chunk with a length; this allows the reader to safely // limit reads while buffering. l := uint32(c.compressed.Len()) - if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil { + + binary.BigEndian.PutUint32(w.scratch[:], l) + if _, err := w.out.Write(w.scratch[:4]); err != nil { return err } diff --git a/pkg/linuxerr/BUILD b/pkg/linuxerr/BUILD new file mode 100644 index 000000000..c5abbd34f --- /dev/null +++ b/pkg/linuxerr/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "linuxerr", + srcs = ["linuxerr.go"], + visibility = ["//visibility:public"], + deps = ["//pkg/abi/linux"], +) + +go_test( + name = "linuxerr_test", + srcs = ["linuxerr_test.go"], + deps = [ + ":linuxerr", + "//pkg/syserror", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/linuxerr/linuxerr.go b/pkg/linuxerr/linuxerr.go new file mode 100644 index 000000000..f45caaadf --- /dev/null +++ b/pkg/linuxerr/linuxerr.go @@ -0,0 +1,184 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package linuxerr contains syscall error codes exported as an error interface +// pointers. This allows for fast comparison and return operations comperable +// to unix.Errno constants. +package linuxerr + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// Error represents a syscall errno with a descriptive message. +type Error struct { + errno linux.Errno + message string +} + +func new(err linux.Errno, message string) *Error { + return &Error{ + errno: err, + message: message, + } +} + +// Error implements error.Error. +func (e *Error) Error() string { return e.message } + +// Errno returns the underlying linux.Errno value. +func (e *Error) Errno() linux.Errno { return e.errno } + +// The following varables have the same meaning as their errno equivalent. + +// Errno values from include/uapi/asm-generic/errno-base.h. +var ( + EPERM = new(linux.EPERM, "operation not permitted") + ENOENT = new(linux.ENOENT, "no such file or directory") + ESRCH = new(linux.ESRCH, "no such process") + EINTR = new(linux.EINTR, "interrupted system call") + EIO = new(linux.EIO, "I/O error") + ENXIO = new(linux.ENXIO, "no such device or address") + E2BIG = new(linux.E2BIG, "argument list too long") + ENOEXEC = new(linux.ENOEXEC, "exec format error") + EBADF = new(linux.EBADF, "bad file number") + ECHILD = new(linux.ECHILD, "no child processes") + EAGAIN = new(linux.EAGAIN, "try again") + ENOMEM = new(linux.ENOMEM, "out of memory") + EACCES = new(linux.EACCES, "permission denied") + EFAULT = new(linux.EFAULT, "bad address") + ENOTBLK = new(linux.ENOTBLK, "block device required") + EBUSY = new(linux.EBUSY, "device or resource busy") + EEXIST = new(linux.EEXIST, "file exists") + EXDEV = new(linux.EXDEV, "cross-device link") + ENODEV = new(linux.ENODEV, "no such device") + ENOTDIR = new(linux.ENOTDIR, "not a directory") + EISDIR = new(linux.EISDIR, "is a directory") + EINVAL = new(linux.EINVAL, "invalid argument") + ENFILE = new(linux.ENFILE, "file table overflow") + EMFILE = new(linux.EMFILE, "too many open files") + ENOTTY = new(linux.ENOTTY, "not a typewriter") + ETXTBSY = new(linux.ETXTBSY, "text file busy") + EFBIG = new(linux.EFBIG, "file too large") + ENOSPC = new(linux.ENOSPC, "no space left on device") + ESPIPE = new(linux.ESPIPE, "illegal seek") + EROFS = new(linux.EROFS, "read-only file system") + EMLINK = new(linux.EMLINK, "too many links") + EPIPE = new(linux.EPIPE, "broken pipe") + EDOM = new(linux.EDOM, "math argument out of domain of func") + ERANGE = new(linux.ERANGE, "math result not representable") +) + +// Errno values from include/uapi/asm-generic/errno.h. +var ( + EDEADLK = new(linux.EDEADLK, "resource deadlock would occur") + ENAMETOOLONG = new(linux.ENAMETOOLONG, "file name too long") + ENOLCK = new(linux.ENOLCK, "no record locks available") + ENOSYS = new(linux.ENOSYS, "invalid system call number") + ENOTEMPTY = new(linux.ENOTEMPTY, "directory not empty") + ELOOP = new(linux.ELOOP, "too many symbolic links encountered") + EWOULDBLOCK = new(linux.EWOULDBLOCK, "operation would block") + ENOMSG = new(linux.ENOMSG, "no message of desired type") + EIDRM = new(linux.EIDRM, "identifier removed") + ECHRNG = new(linux.ECHRNG, "channel number out of range") + EL2NSYNC = new(linux.EL2NSYNC, "level 2 not synchronized") + EL3HLT = new(linux.EL3HLT, "level 3 halted") + EL3RST = new(linux.EL3RST, "level 3 reset") + ELNRNG = new(linux.ELNRNG, "link number out of range") + EUNATCH = new(linux.EUNATCH, "protocol driver not attached") + ENOCSI = new(linux.ENOCSI, "no CSI structure available") + EL2HLT = new(linux.EL2HLT, "level 2 halted") + EBADE = new(linux.EBADE, "invalid exchange") + EBADR = new(linux.EBADR, "invalid request descriptor") + EXFULL = new(linux.EXFULL, "exchange full") + ENOANO = new(linux.ENOANO, "no anode") + EBADRQC = new(linux.EBADRQC, "invalid request code") + EBADSLT = new(linux.EBADSLT, "invalid slot") + EDEADLOCK = new(linux.EDEADLOCK, EDEADLK.message) + EBFONT = new(linux.EBFONT, "bad font file format") + ENOSTR = new(linux.ENOSTR, "device not a stream") + ENODATA = new(linux.ENODATA, "no data available") + ETIME = new(linux.ETIME, "timer expired") + ENOSR = new(linux.ENOSR, "out of streams resources") + ENONET = new(linux.ENOENT, "machine is not on the network") + ENOPKG = new(linux.ENOPKG, "package not installed") + EREMOTE = new(linux.EREMOTE, "object is remote") + ENOLINK = new(linux.ENOLINK, "link has been severed") + EADV = new(linux.EADV, "advertise error") + ESRMNT = new(linux.ESRMNT, "srmount error") + ECOMM = new(linux.ECOMM, "communication error on send") + EPROTO = new(linux.EPROTO, "protocol error") + EMULTIHOP = new(linux.EMULTIHOP, "multihop attempted") + EDOTDOT = new(linux.EDOTDOT, "RFS specific error") + EBADMSG = new(linux.EBADMSG, "not a data message") + EOVERFLOW = new(linux.EOVERFLOW, "value too large for defined data type") + ENOTUNIQ = new(linux.ENOTUNIQ, "name not unique on network") + EBADFD = new(linux.EBADFD, "file descriptor in bad state") + EREMCHG = new(linux.EREMCHG, "remote address changed") + ELIBACC = new(linux.ELIBACC, "can not access a needed shared library") + ELIBBAD = new(linux.ELIBBAD, "accessing a corrupted shared library") + ELIBSCN = new(linux.ELIBSCN, ".lib section in a.out corrupted") + ELIBMAX = new(linux.ELIBMAX, "attempting to link in too many shared libraries") + ELIBEXEC = new(linux.ELIBEXEC, "cannot exec a shared library directly") + EILSEQ = new(linux.EILSEQ, "illegal byte sequence") + ERESTART = new(linux.ERESTART, "interrupted system call should be restarted") + ESTRPIPE = new(linux.ESTRPIPE, "streams pipe error") + EUSERS = new(linux.EUSERS, "too many users") + ENOTSOCK = new(linux.ENOTSOCK, "socket operation on non-socket") + EDESTADDRREQ = new(linux.EDESTADDRREQ, "destination address required") + EMSGSIZE = new(linux.EMSGSIZE, "message too long") + EPROTOTYPE = new(linux.EPROTOTYPE, "protocol wrong type for socket") + ENOPROTOOPT = new(linux.ENOPROTOOPT, "protocol not available") + EPROTONOSUPPORT = new(linux.EPROTONOSUPPORT, "protocol not supported") + ESOCKTNOSUPPORT = new(linux.ESOCKTNOSUPPORT, "socket type not supported") + EOPNOTSUPP = new(linux.EOPNOTSUPP, "operation not supported on transport endpoint") + EPFNOSUPPORT = new(linux.EPFNOSUPPORT, "protocol family not supported") + EAFNOSUPPORT = new(linux.EAFNOSUPPORT, "address family not supported by protocol") + EADDRINUSE = new(linux.EADDRINUSE, "address already in use") + EADDRNOTAVAIL = new(linux.EADDRNOTAVAIL, "cannot assign requested address") + ENETDOWN = new(linux.ENETDOWN, "network is down") + ENETUNREACH = new(linux.ENETUNREACH, "network is unreachable") + ENETRESET = new(linux.ENETRESET, "network dropped connection because of reset") + ECONNABORTED = new(linux.ECONNABORTED, "software caused connection abort") + ECONNRESET = new(linux.ECONNRESET, "connection reset by peer") + ENOBUFS = new(linux.ENOBUFS, "no buffer space available") + EISCONN = new(linux.EISCONN, "transport endpoint is already connected") + ENOTCONN = new(linux.ENOTCONN, "transport endpoint is not connected") + ESHUTDOWN = new(linux.ESHUTDOWN, "cannot send after transport endpoint shutdown") + ETOOMANYREFS = new(linux.ETOOMANYREFS, "too many references: cannot splice") + ETIMEDOUT = new(linux.ETIMEDOUT, "connection timed out") + ECONNREFUSED = new(linux.ECONNREFUSED, "connection refused") + EHOSTDOWN = new(linux.EHOSTDOWN, "host is down") + EHOSTUNREACH = new(linux.EHOSTUNREACH, "no route to host") + EALREADY = new(linux.EALREADY, "operation already in progress") + EINPROGRESS = new(linux.EINPROGRESS, "operation now in progress") + ESTALE = new(linux.ESTALE, "stale file handle") + EUCLEAN = new(linux.EUCLEAN, "structure needs cleaning") + ENOTNAM = new(linux.ENOTNAM, "not a XENIX named type file") + ENAVAIL = new(linux.ENAVAIL, "no XENIX semaphores available") + EISNAM = new(linux.EISNAM, "is a named type file") + EREMOTEIO = new(linux.EREMOTEIO, "remote I/O error") + EDQUOT = new(linux.EDQUOT, "quota exceeded") + ENOMEDIUM = new(linux.ENOMEDIUM, "no medium found") + EMEDIUMTYPE = new(linux.EMEDIUMTYPE, "wrong medium type") + ECANCELED = new(linux.ECANCELED, "operation Canceled") + ENOKEY = new(linux.ENOKEY, "required key not available") + EKEYEXPIRED = new(linux.EKEYEXPIRED, "key has expired") + EKEYREVOKED = new(linux.EKEYREVOKED, "key has been revoked") + EKEYREJECTED = new(linux.EKEYREJECTED, "key was rejected by service") + EOWNERDEAD = new(linux.EOWNERDEAD, "owner died") + ENOTRECOVERABLE = new(linux.ENOTRECOVERABLE, "state not recoverable") + ERFKILL = new(linux.ERFKILL, "operation not possible due to RF-kill") + EHWPOISON = new(linux.EHWPOISON, "memory page has hardware error") +) diff --git a/pkg/syserror/syserror_test.go b/pkg/linuxerr/linuxerr_test.go index c141e5f6e..d34937e93 100644 --- a/pkg/syserror/syserror_test.go +++ b/pkg/linuxerr/linuxerr_test.go @@ -19,6 +19,7 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -30,7 +31,13 @@ func BenchmarkAssignErrno(b *testing.B) { } } -func BenchmarkAssignError(b *testing.B) { +func BenchmarkLinuxerrAssignError(b *testing.B) { + for i := b.N; i > 0; i-- { + globalError = linuxerr.EINVAL + } +} + +func BenchmarkAssignSyserrorError(b *testing.B) { for i := b.N; i > 0; i-- { globalError = syserror.EINVAL } @@ -46,7 +53,17 @@ func BenchmarkCompareErrno(b *testing.B) { } } -func BenchmarkCompareError(b *testing.B) { +func BenchmarkCompareLinuxerrError(b *testing.B) { + globalError = linuxerr.E2BIG + j := 0 + for i := b.N; i > 0; i-- { + if globalError == linuxerr.EINVAL { + j++ + } + } +} + +func BenchmarkCompareSyserrorError(b *testing.B) { globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { @@ -62,7 +79,7 @@ func BenchmarkSwitchErrno(b *testing.B) { for i := b.N; i > 0; i-- { switch globalError { case unix.EINVAL: - j += 1 + j++ case unix.EINTR: j += 2 case unix.EAGAIN: @@ -71,13 +88,28 @@ func BenchmarkSwitchErrno(b *testing.B) { } } -func BenchmarkSwitchError(b *testing.B) { +func BenchmarkSwitchLinuxerrError(b *testing.B) { + globalError = linuxerr.EPERM + j := 0 + for i := b.N; i > 0; i-- { + switch globalError { + case linuxerr.EINVAL: + j++ + case linuxerr.EINTR: + j += 2 + case linuxerr.EAGAIN: + j += 3 + } + } +} + +func BenchmarkSwitchSyserrorError(b *testing.B) { globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { case syserror.EINVAL: - j += 1 + j++ case syserror.EINTR: j += 2 case syserror.EAGAIN: diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD index 7cd89e639..7a5002176 100644 --- a/pkg/marshal/BUILD +++ b/pkg/marshal/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "marshal.go", "marshal_impl_util.go", + "util.go", ], visibility = [ "//:sandbox", diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 32c8ed138..6f38992b7 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -125,6 +125,81 @@ func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { var _ marshal.Marshallable = (*ByteSlice)(nil) +// The following set of functions are convenient shorthands for wrapping a +// built-in type in a marshallable primitive type. For example: +// +// func useMarshallable(m marshal.Marshallable) { ... } +// +// // Compare: +// +// buf = []byte{...} +// // useMarshallable(&primitive.ByteSlice(buf)) // Not allowed, can't address temp value. +// bufP := primitive.ByteSlice(buf) +// useMarshallable(&bufP) +// +// // Vs: +// +// useMarshallable(AsByteSlice(buf)) +// +// Note that the argument to these function escapes, so avoid using them on very +// hot code paths. But generally if a function accepts an interface as an +// argument, the argument escapes anyways. + +// AllocateInt8 returns x as a marshallable. +func AllocateInt8(x int8) marshal.Marshallable { + p := Int8(x) + return &p +} + +// AllocateUint8 returns x as a marshallable. +func AllocateUint8(x uint8) marshal.Marshallable { + p := Uint8(x) + return &p +} + +// AllocateInt16 returns x as a marshallable. +func AllocateInt16(x int16) marshal.Marshallable { + p := Int16(x) + return &p +} + +// AllocateUint16 returns x as a marshallable. +func AllocateUint16(x uint16) marshal.Marshallable { + p := Uint16(x) + return &p +} + +// AllocateInt32 returns x as a marshallable. +func AllocateInt32(x int32) marshal.Marshallable { + p := Int32(x) + return &p +} + +// AllocateUint32 returns x as a marshallable. +func AllocateUint32(x uint32) marshal.Marshallable { + p := Uint32(x) + return &p +} + +// AllocateInt64 returns x as a marshallable. +func AllocateInt64(x int64) marshal.Marshallable { + p := Int64(x) + return &p +} + +// AllocateUint64 returns x as a marshallable. +func AllocateUint64(x uint64) marshal.Marshallable { + p := Uint64(x) + return &p +} + +// AsByteSlice returns b as a marshallable. Note that this allocates a new slice +// header, but does not copy the slice contents. +func AsByteSlice(b []byte) marshal.Marshallable { + bs := ByteSlice(b) + return &bs +} + // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. diff --git a/pkg/marshal/util.go b/pkg/marshal/util.go new file mode 100644 index 000000000..c1e5475bd --- /dev/null +++ b/pkg/marshal/util.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package marshal + +// Marshal returns the serialized contents of m in a newly allocated +// byte slice. +func Marshal(m Marshallable) []byte { + buf := make([]byte, m.SizeBytes()) + m.MarshalUnsafe(buf) + return buf +} diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 6450f664c..ac7868ad9 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -36,7 +36,6 @@ const ( ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config SHA384. func DigestSize(hashAlgorithm int) int { switch hashAlgorithm { case linux.FS_VERITY_HASH_ALG_SHA256: @@ -69,7 +68,6 @@ func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) blockSize: hostarch.PageSize, } - // TODO(b/156980949): Allow config SHA384. switch hashAlgorithms { case linux.FS_VERITY_HASH_ALG_SHA256: layout.digestSize = sha256DigestSize @@ -429,8 +427,6 @@ func Verify(params *VerifyParams) (int64, error) { } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. - // TODO(b/162908070): Investigate possible issues with zero - // padding the data. if bytesRead < len(buf) { for j := bytesRead; j < len(buf); j++ { buf[j] = 0 diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index c23a1b52c..fdeee3a5f 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -36,10 +36,17 @@ var ( // new metric after initialization. ErrInitializationDone = errors.New("metric cannot be created after initialization is complete") + // createdSentryMetrics indicates that the sentry metrics are created. + createdSentryMetrics = false + // WeirdnessMetric is a metric with fields created to track the number - // of weird occurrences such as clock fallback, partial_result and - // vsyscall count. + // of weird occurrences such as time fallback, partial_result, vsyscall + // count, watchdog startup timeouts and stuck tasks. WeirdnessMetric *Uint64Metric + + // SuspiciousOperationsMetric is a metric with fields created to detect + // operations such as opening an executable file to write from a gofer. + SuspiciousOperationsMetric *Uint64Metric ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be @@ -388,13 +395,21 @@ func EmitMetricUpdate() { // CreateSentryMetrics creates the sentry metrics during kernel initialization. func CreateSentryMetrics() { - if WeirdnessMetric != nil { + if createdSentryMetrics { return } - WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as clock fallback, partial result and vsyscalls invoked in the sandbox", + createdSentryMetrics = true + + WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", Field{ name: "weirdness_type", - allowedValues: []string{"fallback", "partial_result", "vsyscall_count"}, + allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, + }) + + SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", + Field{ + name: "operation_type", + allowedValues: []string{"opened_write_execute_file"}, }) } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 7abc82e1b..28396b0ea 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -121,6 +121,22 @@ func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, At return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil } +func (c *clientFile) MultiGetAttr(names []string) ([]FullStat, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, unix.EBADF + } + + if !versionSupportsTmultiGetAttr(c.client.version) { + return DefaultMultiGetAttr(c, names) + } + + rmultigetattr := Rmultigetattr{} + if err := c.client.sendRecv(&Tmultigetattr{FID: c.fid, Names: names}, &rmultigetattr); err != nil { + return nil, err + } + return rmultigetattr.Stats, nil +} + // StatFS implements File.StatFS. func (c *clientFile) StatFS() (FSStat, error) { if atomic.LoadUint32(&c.closed) != 0 { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index c59c6a65b..97e0231d6 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -15,6 +15,8 @@ package p9 import ( + "errors" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/fd" ) @@ -72,6 +74,15 @@ type File interface { // On the server, WalkGetAttr has a read concurrency guarantee. WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) + // MultiGetAttr batches up multiple calls to GetAttr(). names is a list of + // path components similar to Walk(). If the first component name is empty, + // the current file is stat'd and included in the results. If the walk reaches + // a file that doesn't exist or not a directory, MultiGetAttr returns the + // partial result with no error. + // + // On the server, MultiGetAttr has a read concurrency guarantee. + MultiGetAttr(names []string) ([]FullStat, error) + // StatFS returns information about the file system associated with // this file. // @@ -306,6 +317,53 @@ func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { type DisallowServerCalls struct{} // Renamed implements File.Renamed. -func (*clientFile) Renamed(File, string) { +func (*DisallowServerCalls) Renamed(File, string) { panic("Renamed should not be called on the client") } + +// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File. +func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { + stats := make([]FullStat, 0, len(names)) + parent := start + mask := AttrMaskAll() + for i, name := range names { + if len(name) == 0 && i == 0 { + qid, valid, attr, err := parent.GetAttr(mask) + if err != nil { + return nil, err + } + stats = append(stats, FullStat{ + QID: qid, + Valid: valid, + Attr: attr, + }) + continue + } + qids, child, valid, attr, err := parent.WalkGetAttr([]string{name}) + if parent != start { + _ = parent.Close() + } + if err != nil { + if errors.Is(err, unix.ENOENT) { + return stats, nil + } + return nil, err + } + stats = append(stats, FullStat{ + QID: qids[0], + Valid: valid, + Attr: attr, + }) + if attr.Mode.FileType() != ModeDirectory { + // Doesn't need to continue if entry is not a dir. Including symlinks + // that cannot be followed. + _ = child.Close() + break + } + parent = child + } + if parent != start { + _ = parent.Close() + } + return stats, nil +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 58312d0cc..758e11b13 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -1421,3 +1421,31 @@ func (t *Tchannel) handle(cs *connState) message { } return rchannel } + +// handle implements handler.handle. +func (t *Tmultigetattr) handle(cs *connState) message { + for i, name := range t.Names { + if len(name) == 0 && i == 0 { + // Empty name is allowed on the first entry to indicate that the current + // FID needs to be included in the result. + continue + } + if err := checkSafeName(name); err != nil { + return newErr(err) + } + } + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(unix.EBADF) + } + defer ref.DecRef() + + var stats []FullStat + if err := ref.safelyRead(func() (err error) { + stats, err = ref.file.MultiGetAttr(t.Names) + return err + }); err != nil { + return newErr(err) + } + return &Rmultigetattr{Stats: stats} +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index cf13cbb69..2ff4694c0 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -254,8 +254,8 @@ func (r *Rwalk) decode(b *buffer) { // encode implements encoder.encode. func (r *Rwalk) encode(b *buffer) { b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2243,8 +2243,8 @@ func (r *Rwalkgetattr) encode(b *buffer) { r.Valid.encode(b) r.Attr.encode(b) b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2552,6 +2552,80 @@ func (r *Rchannel) String() string { return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } +// Tmultigetattr is a multi-getattr request. +type Tmultigetattr struct { + // FID is the FID to be walked. + FID FID + + // Names are the set of names to be walked. + Names []string +} + +// decode implements encoder.decode. +func (t *Tmultigetattr) decode(b *buffer) { + t.FID = b.ReadFID() + n := b.Read16() + t.Names = t.Names[:0] + for i := 0; i < int(n); i++ { + t.Names = append(t.Names, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (t *Tmultigetattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write16(uint16(len(t.Names))) + for _, name := range t.Names { + b.WriteString(name) + } +} + +// Type implements message.Type. +func (*Tmultigetattr) Type() MsgType { + return MsgTmultigetattr +} + +// String implements fmt.Stringer. +func (t *Tmultigetattr) String() string { + return fmt.Sprintf("Tmultigetattr{FID: %d, Names: %v}", t.FID, t.Names) +} + +// Rmultigetattr is a multi-getattr response. +type Rmultigetattr struct { + // Stats are the set of FullStat returned for each of the names in the + // request. + Stats []FullStat +} + +// decode implements encoder.decode. +func (r *Rmultigetattr) decode(b *buffer) { + n := b.Read16() + r.Stats = r.Stats[:0] + for i := 0; i < int(n); i++ { + var fs FullStat + fs.decode(b) + r.Stats = append(r.Stats, fs) + } +} + +// encode implements encoder.encode. +func (r *Rmultigetattr) encode(b *buffer) { + b.Write16(uint16(len(r.Stats))) + for i := range r.Stats { + r.Stats[i].encode(b) + } +} + +// Type implements message.Type. +func (*Rmultigetattr) Type() MsgType { + return MsgRmultigetattr +} + +// String implements fmt.Stringer. +func (r *Rmultigetattr) String() string { + return fmt.Sprintf("Rmultigetattr{Stats: %v}", r.Stats) +} + const maxCacheSize = 3 // msgFactory is used to reduce allocations by caching messages for reuse. @@ -2717,6 +2791,8 @@ func init() { msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) + msgRegistry.register(MsgTmultigetattr, func() message { return &Tmultigetattr{} }) + msgRegistry.register(MsgRmultigetattr, func() message { return &Rmultigetattr{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 648cf4b49..3d452a0bd 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -402,6 +402,8 @@ const ( MsgRallocate MsgType = 139 MsgTsetattrclunk MsgType = 140 MsgRsetattrclunk MsgType = 141 + MsgTmultigetattr MsgType = 142 + MsgRmultigetattr MsgType = 143 MsgTchannel MsgType = 250 MsgRchannel MsgType = 251 ) @@ -1178,3 +1180,29 @@ func (a *AllocateMode) encode(b *buffer) { } b.Write32(mask) } + +// FullStat is used in the result of a MultiGetAttr call. +type FullStat struct { + QID QID + Valid AttrMask + Attr Attr +} + +// String implements fmt.Stringer. +func (f *FullStat) String() string { + return fmt.Sprintf("FullStat{QID: %v, Valid: %v, Attr: %v}", f.QID, f.Valid, f.Attr) +} + +// decode implements encoder.decode. +func (f *FullStat) decode(b *buffer) { + f.QID.decode(b) + f.Valid.decode(b) + f.Attr.decode(b) +} + +// encode implements encoder.encode. +func (f *FullStat) encode(b *buffer) { + f.QID.encode(b) + f.Valid.encode(b) + f.Attr.encode(b) +} diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 8d7168ef5..950236162 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 12 + highestSupportedVersion uint32 = 13 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -179,3 +179,9 @@ func versionSupportsListRemoveXattr(v uint32) bool { func versionSupportsTsetattrclunk(v uint32) bool { return v >= 12 } + +// versionSupportsTmultiGetAttr returns true if version v supports +// the TmultiGetAttr message. +func versionSupportsTmultiGetAttr(v uint32) bool { + return v >= 13 +} diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 6992e1de8..4aecb8007 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -30,6 +30,9 @@ import ( // RefCounter is the interface to be implemented by objects that are reference // counted. +// +// TODO(gvisor.dev/issue/1624): Get rid of most of this package and replace it +// with refsvfs2. type RefCounter interface { // IncRef increments the reference counter on the object. IncRef() @@ -181,6 +184,9 @@ func (w *WeakRef) zap() { // AtomicRefCount keeps a reference count using atomic operations and calls the // destructor when the count reaches zero. // +// Do not use AtomicRefCount for new ref-counted objects! It is deprecated in +// favor of the refsvfs2 package. +// // N.B. To allow the zero-object to be initialized, the count is offset by // 1, that is, when refCount is n, there are really n+1 references. // @@ -215,8 +221,8 @@ type AtomicRefCount struct { // LeakMode configures the leak checker. type LeakMode uint32 -// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref -// counting is gone. +// TODO(gvisor.dev/issue/1624): Simplify down to two modes (on/off) once vfs1 +// ref counting is gone. const ( // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. UninitializedLeakChecking LeakMode = iota diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD index 0377c0876..7c1a8c792 100644 --- a/pkg/refsvfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -1,3 +1,5 @@ +# TODO(gvisor.dev/issue/1624): rename this directory/package to "refs" once VFS1 +# is gone and the current refs package can be deleted. load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template") diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index 3fbc91aa5..1102c8adc 100644 --- a/pkg/refsvfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -13,7 +13,7 @@ // limitations under the License. // Package refs_template defines a template that can be used by reference -// counted objects. +// counted objects. The template comes with leak checking capabilities. package refs_template import ( @@ -40,6 +40,14 @@ var obj *T // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. // +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// // +stateify savable type Refs struct { // refCount is composed of two fields: diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 3f17fba49..9dac53c80 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,3 +322,12 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } + +// PrefaultRootTable touches the root table page to be sure that its physical +// pages are mapped. +// +//go:nosplit +//go:noinline +func (p *PageTables) PrefaultRootTable() PTE { + return p.root[0] +} diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s index 290579e53..d513f16c9 100644 --- a/pkg/safecopy/atomic_amd64.s +++ b/pkg/safecopy/atomic_amd64.s @@ -24,12 +24,12 @@ TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 MOVL DI, sig+20(FP) RET -// swapUint32 atomically stores new into *addr and returns (the previous *addr +// swapUint32 atomically stores new into *ptr and returns (the previous ptr* // value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the // value of old is unspecified, and sig is the number of the signal that was // received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) TEXT ·swapUint32(SB), NOSPLIT, $0-24 @@ -38,7 +38,7 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 // handleSwapUint32Fault will store a different value in this address. MOVL $0, sig+20(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVL new+8(FP), AX XCHGL AX, 0(DI) MOVL AX, old+16(FP) @@ -60,12 +60,12 @@ TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 MOVL DI, sig+24(FP) RET -// swapUint64 atomically stores new into *addr and returns (the previous *addr +// swapUint64 atomically stores new into *ptr and returns (the previous *ptr // value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the // value of old is unspecified, and sig is the number of the signal that was // received. // -// Preconditions: addr must be aligned to a 8-byte boundary. +// Preconditions: ptr must be aligned to a 8-byte boundary. // //func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) TEXT ·swapUint64(SB), NOSPLIT, $0-28 @@ -74,7 +74,7 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 // handleSwapUint64Fault will store a different value in this address. MOVL $0, sig+24(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVQ new+8(FP), AX XCHGQ AX, 0(DI) MOVQ AX, old+16(FP) @@ -97,11 +97,11 @@ TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 RET // compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns -// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is +// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is // received during the operation, the value of prev is unspecified, and sig is // the number of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 @@ -111,7 +111,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 // address. MOVL $0, sig+20(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVL old+8(FP), AX MOVL new+12(FP), DX LOCK @@ -135,11 +135,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 MOVL DI, sig+12(FP) RET -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS // signal is received, the value returned is unspecified, and sig is the number // of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) TEXT ·loadUint32(SB), NOSPLIT, $0-16 @@ -148,7 +148,7 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 // handleLoadUint32Fault will store a different value in this address. MOVL $0, sig+12(FP) - MOVQ addr+0(FP), AX + MOVQ ptr+0(FP), AX MOVL (AX), BX MOVL BX, val+8(FP) RET diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s index 55c031a3c..246a049ba 100644 --- a/pkg/safecopy/atomic_arm64.s +++ b/pkg/safecopy/atomic_arm64.s @@ -25,7 +25,7 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 // handleSwapUint32Fault will store a different value in this address. MOVW $0, sig+20(FP) again: - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVW new+8(FP), R1 LDAXRW (R0), R2 STLXRW R1, (R0), R3 @@ -60,7 +60,7 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 // handleSwapUint64Fault will store a different value in this address. MOVW $0, sig+24(FP) again: - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVD new+8(FP), R1 LDAXR (R0), R2 STLXR R1, (R0), R3 @@ -96,7 +96,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 // address. MOVW $0, sig+20(FP) - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVW old+8(FP), R1 MOVW new+12(FP), R2 again: @@ -125,11 +125,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 MOVW R1, sig+12(FP) RET -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS // signal is received, the value returned is unspecified, and sig is the number // of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) TEXT ·loadUint32(SB), NOSPLIT, $0-16 @@ -138,7 +138,7 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 // handleLoadUint32Fault will store a different value in this address. MOVW $0, sig+12(FP) - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 LDARW (R0), R1 MOVW R1, val+8(FP) RET diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 1d63ca1fd..37316b2f5 100644 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s @@ -51,8 +51,8 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36 // handleMemcpyFault will store a different value in this address. MOVL $0, sig+32(FP) - MOVQ to+0(FP), DI - MOVQ from+8(FP), SI + MOVQ dst+0(FP), DI + MOVQ src+8(FP), SI MOVQ n+16(FP), BX tail: diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s index 7b3f50aa5..50f5b754b 100644 --- a/pkg/safecopy/memcpy_arm64.s +++ b/pkg/safecopy/memcpy_arm64.s @@ -33,8 +33,8 @@ TEXT ·memcpy(SB), NOSPLIT, $-8-36 // handleMemcpyFault will store a different value in this address. MOVW $0, sig+32(FP) - MOVD to+0(FP), R3 - MOVD from+8(FP), R4 + MOVD dst+0(FP), R3 + MOVD src+8(FP), R4 MOVD n+16(FP), R5 CMP $0, R5 BNE check diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD new file mode 100644 index 000000000..d09214e3e --- /dev/null +++ b/pkg/sentry/devices/quotedev/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "quotedev", + srcs = ["quotedev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go new file mode 100644 index 000000000..6114cb724 --- /dev/null +++ b/pkg/sentry/devices/quotedev/quotedev.go @@ -0,0 +1,52 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package quotedev implements a vfs.Device for /dev/gvisor_quote. +package quotedev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + quoteDevMinor = 0 +) + +// quoteDevice implements vfs.Device for /dev/gvisor_quote +// +// +stateify savable +type quoteDevice struct{} + +// Open implements vfs.Device.Open. +// TODO(b/157161182): Add support for attestation ioctls. +func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + return nil, syserror.EIO +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "gvisor_quote", + }) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package. +func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index c4a069832..94cb05246 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/fd", "//pkg/hostarch", "//pkg/log", + "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/safemem", diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 8f5a87120..819e140bc 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -91,7 +92,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } if flags.Write { if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil { - fsmetric.GoferOpensWX.Increment() + metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") log.Warningf("Opened a writable executable: %q", name) } } diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 1d09afdd7..4893af56b 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -403,7 +403,7 @@ type ipForwarding struct { // enabled stores the IPv4 forwarding state on save. // We must save/restore this here, since a netstack instance // is created on restore. - enabled *bool + enabled bool } func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { @@ -461,13 +461,8 @@ func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return 0, io.EOF } - if f.ipf.enabled == nil { - enabled := f.stack.Forwarding(ipv4.ProtocolNumber) - f.ipf.enabled = &enabled - } - val := "0\n" - if *f.ipf.enabled { + if f.ipf.enabled { // Technically, this is not quite compatible with Linux. Linux // stores these as an integer, so if you write "2" into // ip_forward, you should get 2 back. @@ -494,11 +489,8 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO if err != nil { return n, err } - if f.ipf.enabled == nil { - f.ipf.enabled = new(bool) - } - *f.ipf.enabled = v != 0 - return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) + f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, f.ipf.enabled) } // portRangeInode implements fs.InodeOperations. It provides and allows diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 4cb4741af..51d2be647 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -47,9 +47,7 @@ func (s *tcpSack) afterLoad() { // afterLoad is invoked by stateify. func (ipf *ipForwarding) afterLoad() { - if ipf.enabled != nil { - if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { - panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) - } + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, ipf.enabled); err != nil { + panic(fmt.Sprintf("ipf.stack.SetForwarding(%d, %t): %s", ipv4.ProtocolNumber, ipf.enabled, err)) } } diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go index 0f54888d8..6512e9cdb 100644 --- a/pkg/sentry/fsimpl/cgroupfs/base.go +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -68,11 +67,6 @@ func (c *controllerCommon) Enabled() bool { return true } -// Filesystem implements kernel.CgroupController.Filesystem. -func (c *controllerCommon) Filesystem() *vfs.Filesystem { - return c.fs.VFSFilesystem() -} - // RootCgroup implements kernel.CgroupController.RootCgroup. func (c *controllerCommon) RootCgroup() kernel.Cgroup { return c.fs.rootCgroup() diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index bd3e69757..54050de3c 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -109,7 +109,7 @@ type InternalData struct { DefaultControlValues map[string]int64 } -// filesystem implements vfs.FilesystemImpl. +// filesystem implements vfs.FilesystemImpl and kernel.cgroupFS. // // +stateify savable type filesystem struct { @@ -139,6 +139,11 @@ type filesystem struct { tasksMu sync.RWMutex `state:"nosave"` } +// InitializeHierarchyID implements kernel.cgroupFS.InitializeHierarchyID. +func (fs *filesystem) InitializeHierarchyID(hid uint32) { + fs.hierarchyID = hid +} + // Name implements vfs.FilesystemType.Name. func (FilesystemType) Name() string { return Name @@ -284,14 +289,12 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Register controllers. The registry may be modified concurrently, so if we // get an error, we raced with someone else who registered the same // controllers first. - hid, err := r.Register(fs.kcontrollers) - if err != nil { + if err := r.Register(fs.kcontrollers, fs); err != nil { ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err) rootD.DecRef(ctx) fs.VFSFilesystem().DecRef(ctx) return nil, nil, syserror.EBUSY } - fs.hierarchyID = hid // Move all existing tasks to the root of the new hierarchy. k.PopulateNewCgroupHierarchy(fs.rootCgroup()) diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index e6fe0fc0d..daff40cd5 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -36,7 +36,7 @@ const Name = "devtmpfs" // // +stateify savable type FilesystemType struct { - initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1664): not yet supported. + initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. initErr error // fs is the tmpfs filesystem that backs all mounts of this FilesystemType. diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 7b1eec3da..2dbc6bfd5 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -46,7 +46,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fd", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 6d5258a9b..368272f12 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "revalidate.go", "save_restore.go", "socket.go", "special_file.go", @@ -53,6 +54,7 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/log", + "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 4b5621043..91ec4a142 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -117,6 +117,17 @@ func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry { return ds } +// Precondition: !parent.isSynthetic() && !child.isSynthetic(). +func appendNewChildDentry(ds **[]*dentry, parent *dentry, child *dentry) { + // The new child was added to parent and took a ref on the parent (hence + // parent can be removed from cache). A new child has 0 refs for now. So + // checkCachingLocked() should be called on both. Call it first on the parent + // as it may create space in the cache for child to be inserted - hence + // avoiding a cache eviction. + *ds = appendDentry(*ds, parent) + *ds = appendDentry(*ds, child) +} + // Preconditions: ds != nil. func putDentrySlice(ds *[]*dentry) { // Allow dentries to be GC'd. @@ -169,167 +180,96 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[] // * fs.renameMu must be locked. // * d.dirMu must be locked. // * !rp.Done(). -// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up -// to date. +// * If !d.cachedMetadataAuthoritative(), then d and all children that are +// part of rp must have been revalidated. // // Postconditions: The returned dentry's cached metadata is up to date. -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, bool, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, false, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, false, err } + followedSymlink := false afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, followedSymlink, nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, false, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil - } - // We must assume that d.parent is correct, because if d has been moved - // elsewhere in the remote filesystem so that its parent has changed, - // we have no way of determining its new parent's location in the - // filesystem. - // - // Call rp.CheckMount() before updating d.parent's metadata, since if - // we traverse to another mount then d.parent's metadata is irrelevant. - if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return d, followedSymlink, nil } - if d != d.parent && !d.cachedMetadataAuthoritative() { - if err := d.parent.updateFromGetattr(ctx); err != nil { - return nil, err - } + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, false, err } rp.Advance() - return d.parent, nil + return d.parent, followedSymlink, nil } - child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds) + child, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err - } - if child == nil { - return nil, syserror.ENOENT + return nil, false, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, false, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx, rp.Mount()) if err != nil { - return nil, err + return nil, false, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, false, err } + followedSymlink = true goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, followedSymlink, nil } // getChildLocked returns a dentry representing the child of parent with the -// given name. If no such child exists, getChildLocked returns (nil, nil). +// given name. Returns ENOENT if the child doesn't exist. // // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. // * parent.isDir(). // * name is not "." or "..". -// -// Postconditions: If getChildLocked returns a non-nil dentry, its cached -// metadata is up to date. -func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +// * dentry at name has been revalidated +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if len(name) > maxFilenameLen { return nil, syserror.ENAMETOOLONG } - child, ok := parent.children[name] - if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() { - // Whether child is nil or not, it is cached information that is - // assumed to be correct. + if child, ok := parent.children[name]; ok || parent.isSynthetic() { + if child == nil { + return nil, syserror.ENOENT + } return child, nil } - // We either don't have cached information or need to verify that it's - // still correct, either of which requires a remote lookup. Check if this - // name is valid before performing the lookup. - return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds) -} -// Preconditions: Same as getChildLocked, plus: -// * !parent.isSynthetic(). -func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { - if child != nil { - // Need to lock child.metadataMu because we might be updating child - // metadata. We need to hold the lock *before* getting metadata from the - // server and release it after updating local metadata. - child.metadataMu.Lock() - } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) - if err != nil && err != syserror.ENOENT { - if child != nil { - child.metadataMu.Unlock() + if err != nil { + if err == syserror.ENOENT { + parent.cacheNegativeLookupLocked(name) } return nil, err } - if child != nil { - if !file.isNil() && qid.Path == child.qidPath { - // The file at this path hasn't changed. Just update cached metadata. - file.close(ctx) - child.updateFromP9AttrsLocked(attrMask, &attr) - child.metadataMu.Unlock() - return child, nil - } - child.metadataMu.Unlock() - if file.isNil() && child.isSynthetic() { - // We have a synthetic file, and no remote file has arisen to - // replace it. - return child, nil - } - // The file at this path has changed or no longer exists. Mark the - // dentry invalidated, and re-evaluate its caching status (i.e. if it - // has 0 references, drop it). Wait to update parent.children until we - // know what to replace the existing dentry with (i.e. one of the - // returns below), to avoid a redundant map access. - vfsObj.InvalidateDentry(ctx, &child.vfsd) - if child.isSynthetic() { - // Normally we don't mark invalidated dentries as deleted since - // they may still exist (but at a different path), and also for - // consistency with Linux. However, synthetic files are guaranteed - // to become unreachable if their dentries are invalidated, so - // treat their invalidation as deletion. - child.setDeleted() - parent.syntheticChildren-- - child.decRefNoCaching() - parent.dirents = nil - } - *ds = appendDentry(*ds, child) - } - if file.isNil() { - // No file exists at this path now. Cache the negative lookup if - // allowed. - parent.cacheNegativeLookupLocked(name) - return nil, nil - } + // Create a new dentry representing the file. - child, err = fs.newDentry(ctx, file, qid, attrMask, &attr) + child, err := fs.newDentry(ctx, file, qid, attrMask, &attr) if err != nil { file.close(ctx) delete(parent.children, name) return nil, err } parent.cacheNewChildLocked(child, name) - // For now, child has 0 references, so our caller should call - // child.checkCachingLocked(). parent gained a ref so we should also call - // parent.checkCachingLocked() so it can be removed from the cache if needed. - *ds = appendDentry(*ds, child) - *ds = appendDentry(*ds, parent) + appendNewChildDentry(ds, parent, child) return child, nil } @@ -344,14 +284,22 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up // to date. func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { + if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil { + return nil, err + } for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err } d = next + if followedSymlink { + if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil { + return nil, err + } + } } if !d.isDir() { return nil, syserror.ENOTDIR @@ -364,20 +312,22 @@ func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: fs.renameMu must be locked. func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { d := rp.Start().Impl().(*dentry) - if !d.cachedMetadataAuthoritative() { - // Get updated metadata for rp.Start() as required by fs.stepLocked(). - if err := d.updateFromGetattr(ctx); err != nil { - return nil, err - } + if err := fs.revalidatePath(ctx, rp, d, ds); err != nil { + return nil, err } for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err } d = next + if followedSymlink { + if err := fs.revalidatePath(ctx, rp, d, ds); err != nil { + return nil, err + } + } } if rp.MustBeDir() && !d.isDir() { return nil, syserror.ENOTDIR @@ -397,13 +347,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return err @@ -421,25 +364,47 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if parent.isDeleted() { return syserror.ENOENT } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil { + return err + } parent.dirMu.Lock() defer parent.dirMu.Unlock() - child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds) - switch { - case err != nil && err != syserror.ENOENT: - return err - case child != nil: + if len(name) > maxFilenameLen { + return syserror.ENAMETOOLONG + } + // Check for existence only if caching information is available. Otherwise, + // don't check for existence just yet. We will check for existence if the + // checks for writability fail below. Existence check is done by the creation + // RPCs themselves. + if child, ok := parent.children[name]; ok && child != nil { return syserror.EEXIST } + checkExistence := func() error { + if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT { + return err + } else if child != nil { + return syserror.EEXIST + } + return nil + } mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { + // Existence check takes precedence. + if existenceErr := checkExistence(); existenceErr != nil { + return existenceErr + } return err } defer mnt.EndWrite() if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + // Existence check takes precedence. + if existenceErr := checkExistence(); existenceErr != nil { + return existenceErr + } return err } if !dir && rp.MustBeDir() { @@ -489,13 +454,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return err @@ -521,33 +479,32 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return syserror.EISDIR } } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, parent, rp.Component(), &ds); err != nil { + return err + } + mntns := vfs.MountNamespaceFromContext(ctx) defer mntns.DecRef(ctx) + parent.dirMu.Lock() defer parent.dirMu.Unlock() - child, ok := parent.children[name] - if ok && child == nil { - return syserror.ENOENT - } - - sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0 - if sticky { - if !ok { - // If the sticky bit is set, we need to retrieve the child to determine - // whether removing it is allowed. - child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err != nil { - return err - } - } else if child != nil && !child.cachedMetadataAuthoritative() { - // Make sure the dentry representing the file at name is up to date - // before examining its metadata. - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) - if err != nil { - return err - } + // Load child if sticky bit is set because we need to determine whether + // deletion is allowed. + var child *dentry + if atomic.LoadUint32(&parent.mode)&linux.ModeSticky == 0 { + var ok bool + child, ok = parent.children[name] + if ok && child == nil { + // Hit a negative cached entry, child doesn't exist. + return syserror.ENOENT + } + } else { + child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + if err != nil { + return err } if err := parent.mayDelete(rp.Credentials(), child); err != nil { return err @@ -556,11 +513,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // If a child dentry exists, prepare to delete it. This should fail if it is // a mount point. We detect mount points by speculatively calling - // PrepareDeleteDentry, which fails if child is a mount point. However, we - // may need to revalidate the file in this case to make sure that it has not - // been deleted or replaced on the remote fs, in which case the mount point - // will have disappeared. If calling PrepareDeleteDentry fails again on the - // up-to-date dentry, we can be sure that it is a mount point. + // PrepareDeleteDentry, which fails if child is a mount point. // // Also note that if child is nil, then it can't be a mount point. if child != nil { @@ -575,23 +528,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child.dirMu.Lock() defer child.dirMu.Unlock() if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - // We can skip revalidation in several cases: - // - We are not in InteropModeShared - // - The parent directory is synthetic, in which case the child must also - // be synthetic - // - We already updated the child during the sticky bit check above - if parent.cachedMetadataAuthoritative() || sticky { - return err - } - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) - if err != nil { - return err - } - if child != nil { - if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - return err - } - } + return err } } flags := uint32(0) @@ -723,13 +660,6 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return nil, err @@ -830,7 +760,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // to creating a synthetic one, i.e. one that is kept entirely in memory. // Check that we're not overriding an existing file with a synthetic one. - _, err = fs.stepLocked(ctx, rp, parent, true, ds) + _, _, err = fs.stepLocked(ctx, rp, parent, true, ds) switch { case err == nil: // Step succeeded, another file exists. @@ -891,12 +821,6 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf defer unlock() start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by fs.stepLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if mayCreate && rp.MustBeDir() { @@ -905,6 +829,12 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } + if !start.cachedMetadataAuthoritative() { + // Refresh dentry's attributes before opening. + if err := start.updateFromGetattr(ctx); err != nil { + return nil, err + } + } start.IncRef() defer start.DecRef(ctx) unlock() @@ -926,9 +856,12 @@ afterTrailingSymlink: if mayCreate && rp.MustBeDir() { return nil, syserror.EISDIR } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil { + return nil, err + } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { if parent.isSynthetic() { parent.dirMu.Unlock() @@ -1028,7 +961,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open } return &fd.vfsfd, nil case linux.S_IFLNK: - // Can't open symlinks without O_PATH (which is unimplemented). + // Can't open symlinks without O_PATH, which is handled at the VFS layer. return nil, syserror.ELOOP case linux.S_IFSOCK: if d.isSynthetic() { @@ -1188,7 +1121,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - *ds = appendDentry(*ds, child) // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { @@ -1212,7 +1144,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } // Insert the dentry into the tree. d.cacheNewChildLocked(child, name) - *ds = appendDentry(*ds, d) + appendNewChildDentry(ds, d, child) if d.cachedMetadataAuthoritative() { d.touchCMtime() d.dirents = nil @@ -1297,18 +1229,23 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, newParent, newName, &ds); err != nil { + return err + } + if err := fs.revalidateOne(ctx, vfsObj, oldParent, oldName, &ds); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds) + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } - if renamed == nil { - return syserror.ENOENT - } if err := oldParent.mayDelete(creds, renamed); err != nil { return err } @@ -1337,8 +1274,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.isDeleted() { return syserror.ENOENT } - replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds) - if err != nil { + replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { return err } var replacedVFSD *vfs.Dentry @@ -1402,9 +1339,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // parent isn't actually changing. if oldParent != newParent { oldParent.decRefNoCaching() - ds = appendDentry(ds, oldParent) newParent.IncRef() ds = appendDentry(ds, newParent) + ds = appendDentry(ds, oldParent) if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index fb42c5f62..21692d2ac 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -32,9 +32,9 @@ // specialFileFD.mu // specialFileFD.bufMu // -// Locking dentry.dirMu in multiple dentries requires that either ancestor -// dentries are locked before descendant dentries, or that filesystem.renameMu -// is locked for writing. +// Locking dentry.dirMu and dentry.metadataMu in multiple dentries requires that +// either ancestor dentries are locked before descendant dentries, or that +// filesystem.renameMu is locked for writing. package gofer import ( diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 21b4a96fe..b0a429d42 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -238,3 +238,10 @@ func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, err ctx.UninterruptibleSleepFinish(false) return fdobj, err } + +func (f p9file) multiGetAttr(ctx context.Context, names []string) ([]p9.FullStat, error) { + ctx.UninterruptibleSleepStart(false) + stats, err := f.file.MultiGetAttr(names) + ctx.UninterruptibleSleepFinish(false) + return stats, err +} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f0e7bbaf7..eed05e369 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -59,7 +60,7 @@ func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, return nil, err } if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() + metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if atomic.LoadInt32(&d.mmapFD) >= 0 { fsmetric.GoferOpensHost.Increment() diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go new file mode 100644 index 000000000..8f81f0822 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/revalidate.go @@ -0,0 +1,386 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +type errPartialRevalidation struct{} + +// Error implements error.Error. +func (errPartialRevalidation) Error() string { + return "partial revalidation" +} + +type errRevalidationStepDone struct{} + +// Error implements error.Error. +func (errRevalidationStepDone) Error() string { + return "stop revalidation" +} + +// revalidatePath checks cached dentries for external modification. File +// attributes are refreshed and cache is invalidated in case the dentry has been +// deleted, or a new file/directory created in its place. +// +// Revalidation stops at symlinks and mount points. The caller is responsible +// for revalidating again after symlinks are resolved and after changing to +// different mounts. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidatePath(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file. + if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Done, ds) + rp.Release(ctx) + return err +} + +// revalidateParentDir does the same as revalidatePath, but stops at the parent. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateParentDir(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file and parent is non synthetic. + if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Final, ds) + rp.Release(ctx) + return err +} + +// revalidateOne does the same as revalidatePath, but checks a single dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateOne(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) error { + // Skip revalidation for interop mode different than InteropModeShared or + // if the parent is synthetic (child must be synthetic too, but it cannot be + // replaced without first replacing the parent). + if parent.cachedMetadataAuthoritative() { + return nil + } + + parent.dirMu.Lock() + child, ok := parent.children[name] + parent.dirMu.Unlock() + if !ok { + return nil + } + + state := makeRevalidateState(parent) + defer state.release() + + state.add(name, child) + return fs.revalidateHelper(ctx, vfsObj, state, ds) +} + +// revalidate revalidates path components in rp until done returns true, or +// until a mount point or symlink is reached. It may send multiple MultiGetAttr +// calls to the gofer to handle ".." in the path. +// +// Preconditions: +// * fs.renameMu must be locked. +// * InteropModeShared is in effect. +func (fs *filesystem) revalidate(ctx context.Context, rp *vfs.ResolvingPath, start *dentry, done func() bool, ds **[]*dentry) error { + state := makeRevalidateState(start) + defer state.release() + + // Skip synthetic dentries because the start dentry cannot be replaced in case + // it has been created in the remote file system. + if !start.isSynthetic() { + state.add("", start) + } + +done: + for cur := start; !done(); { + var err error + cur, err = fs.revalidateStep(ctx, rp, cur, state) + if err != nil { + switch err.(type) { + case errPartialRevalidation: + if err := fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds); err != nil { + return err + } + + // Reset state to release any remaining locks and restart from where + // stepping stopped. + state.reset() + state.start = cur + + // Skip synthetic dentries because the start dentry cannot be replaced in + // case it has been created in the remote file system. + if !cur.isSynthetic() { + state.add("", cur) + } + + case errRevalidationStepDone: + break done + + default: + return err + } + } + } + return fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds) +} + +// revalidateStep walks one element of the path and updates revalidationState +// with the dentry if needed. It may also stop the stepping or ask for a +// partial revalidation. Partial revalidation requires the caller to revalidate +// the current revalidationState, release all locks, and resume stepping. +// In case a symlink is hit, revalidation stops and the caller is responsible +// for calling revalidate again after the symlink is resolved. Revalidation may +// also stop for other reasons, like hitting a child not in the cache. +// +// Returns: +// * (dentry, nil): step worked, continue stepping.` +// * (dentry, errPartialRevalidation): revalidation should be done with the +// state gathered so far. Then continue stepping with the remainder of the +// path, starting at `dentry`. +// * (nil, errRevalidationStepDone): revalidation doesn't need to step any +// further. It hit a symlink, a mount point, or an uncached dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). +// * InteropModeShared is in effect (assumes no negative dentries). +func (fs *filesystem) revalidateStep(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, state *revalidateState) (*dentry, error) { + switch name := rp.Component(); name { + case ".": + // Do nothing. + + case "..": + // Partial revalidation is required when ".." is hit because metadata locks + // can only be acquired from parent to child to avoid deadlocks. + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { + return nil, errRevalidationStepDone{} + } else if isRoot || d.parent == nil { + rp.Advance() + return d, errPartialRevalidation{} + } + // We must assume that d.parent is correct, because if d has been moved + // elsewhere in the remote filesystem so that its parent has changed, + // we have no way of determining its new parent's location in the + // filesystem. + // + // Call rp.CheckMount() before updating d.parent's metadata, since if + // we traverse to another mount then d.parent's metadata is irrelevant. + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, errRevalidationStepDone{} + } + rp.Advance() + return d.parent, errPartialRevalidation{} + + default: + d.dirMu.Lock() + child, ok := d.children[name] + d.dirMu.Unlock() + if !ok { + // child is not cached, no need to validate any further. + return nil, errRevalidationStepDone{} + } + + state.add(name, child) + + // Symlink must be resolved before continuing with revalidation. + if child.isSymlink() { + return nil, errRevalidationStepDone{} + } + + d = child + } + + rp.Advance() + return d, nil +} + +// revalidateHelper calls the gofer to stat all dentries in `state`. It will +// update or invalidate dentries in the cache based on the result. +// +// Preconditions: +// * fs.renameMu must be locked. +// * InteropModeShared is in effect. +func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualFilesystem, state *revalidateState, ds **[]*dentry) error { + if len(state.names) == 0 { + return nil + } + // Lock metadata on all dentries *before* getting attributes for them. + state.lockAllMetadata() + stats, err := state.start.file.multiGetAttr(ctx, state.names) + if err != nil { + return err + } + + i := -1 + for d := state.popFront(); d != nil; d = state.popFront() { + i++ + found := i < len(stats) + if i == 0 && len(state.names[0]) == 0 { + if found && !d.isSynthetic() { + // First dentry is where the search is starting, just update attributes + // since it cannot be replaced. + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) + } + d.metadataMu.Unlock() + continue + } + + // Note that synthetic dentries will always fails the comparison check + // below. + if !found || d.qidPath != stats[i].QID.Path { + d.metadataMu.Unlock() + if !found && d.isSynthetic() { + // We have a synthetic file, and no remote file has arisen to replace + // it. + return nil + } + // The file at this path has changed or no longer exists. Mark the + // dentry invalidated, and re-evaluate its caching status (i.e. if it + // has 0 references, drop it). The dentry will be reloaded next time it's + // accessed. + vfsObj.InvalidateDentry(ctx, &d.vfsd) + + name := state.names[i] + d.parent.dirMu.Lock() + + if d.isSynthetic() { + // Normally we don't mark invalidated dentries as deleted since + // they may still exist (but at a different path), and also for + // consistency with Linux. However, synthetic files are guaranteed + // to become unreachable if their dentries are invalidated, so + // treat their invalidation as deletion. + d.setDeleted() + d.decRefNoCaching() + *ds = appendDentry(*ds, d) + + d.parent.syntheticChildren-- + d.parent.dirents = nil + } + + // Since the dirMu was released and reacquired, re-check that the + // parent's child with this name is still the same. Do not touch it if + // it has been replaced with a different one. + if child := d.parent.children[name]; child == d { + // Invalidate dentry so it gets reloaded next time it's accessed. + delete(d.parent.children, name) + } + d.parent.dirMu.Unlock() + + return nil + } + + // The file at this path hasn't changed. Just update cached metadata. + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) + d.metadataMu.Unlock() + } + + return nil +} + +// revalidateStatePool caches revalidateState instances to save array +// allocations for dentries and names. +var revalidateStatePool = sync.Pool{ + New: func() interface{} { + return &revalidateState{} + }, +} + +// revalidateState keeps state related to a revalidation request. It keeps track +// of {name, dentry} list being revalidated, as well as metadata locks on the +// dentries. The list must be in ancestry order, in other words `n` must be +// `n-1` child. +type revalidateState struct { + // start is the dentry where to start the attributes search. + start *dentry + + // List of names of entries to refresh attributes. Names length must be the + // same as detries length. They are kept in separate slices because names is + // used to call File.MultiGetAttr(). + names []string + + // dentries is the list of dentries that correspond to the names above. + // dentry.metadataMu is acquired as each dentry is added to this list. + dentries []*dentry + + // locked indicates if metadata lock has been acquired on dentries. + locked bool +} + +func makeRevalidateState(start *dentry) *revalidateState { + r := revalidateStatePool.Get().(*revalidateState) + r.start = start + return r +} + +// release must be called after the caller is done with this object. It releases +// all metadata locks and resources. +func (r *revalidateState) release() { + r.reset() + revalidateStatePool.Put(r) +} + +// Preconditions: +// * d is a descendant of all dentries in r.dentries. +func (r *revalidateState) add(name string, d *dentry) { + r.names = append(r.names, name) + r.dentries = append(r.dentries, d) +} + +func (r *revalidateState) lockAllMetadata() { + for _, d := range r.dentries { + d.metadataMu.Lock() + } + r.locked = true +} + +func (r *revalidateState) popFront() *dentry { + if len(r.dentries) == 0 { + return nil + } + d := r.dentries[0] + r.dentries = r.dentries[1:] + return d +} + +// reset releases all metadata locks and resets all fields to allow this +// instance to be reused. +func (r *revalidateState) reset() { + if r.locked { + // Unlock any remaining dentries. + for _, d := range r.dentries { + d.metadataMu.Unlock() + } + r.locked = false + } + r.start = nil + r.names = r.names[:0] + r.dentries = r.dentries[:0] +} diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index ac3b5b621..c12444b7e 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -100,7 +101,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci d.fs.specialFileFDs[fd] = struct{}{} d.fs.syncMu.Unlock() if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() + metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if h.fd >= 0 { fsmetric.GoferOpensHost.Increment() diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index badca4d9f..f50b0fb08 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -612,16 +612,24 @@ afterTrailingSymlink: // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() defer fs.processDeferredDecRefs(ctx) - defer fs.mu.RUnlock() + + fs.mu.RLock() d, err := fs.walkExistingLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() return "", err } if !d.isSymlink() { + fs.mu.RUnlock() return "", syserror.EINVAL } + + // Inode.Readlink() cannot be called holding fs locks. + d.IncRef() + defer d.DecRef(ctx) + fs.mu.RUnlock() + return d.inode.Readlink(ctx, rp.Mount()) } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 16486eeae..6f699c9cd 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -534,6 +534,9 @@ func (d *Dentry) FSLocalPath() string { // - Checking that dentries passed to methods are of the appropriate file type. // - Checking permissions. // +// Inode functions may be called holding filesystem wide locks and are not +// allowed to call vfs functions that may reenter, unless otherwise noted. +// // Specific responsibilities of implementations are documented below. type Inode interface { // Methods related to reference counting. A generic implementation is @@ -680,6 +683,9 @@ type inodeDirectory interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. + // + // Readlink is called with no kernfs locks held, so it may reenter if needed + // to resolve symlink targets. Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 02bf74dbc..4718fac7a 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -221,6 +221,8 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) defer file.DecRef(ctx) root := vfs.RootFromContext(ctx) defer root.DecRef(ctx) + + // Note: it's safe to reenter kernfs from Readlink if needed to resolve path. return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 7c7543f14..cf905fae4 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -65,6 +65,7 @@ var _ kernfs.Inode = (*tasksInode)(nil) func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ + "cmdline": fs.newInode(ctx, root, 0444, &cmdLineData{}), "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index e1a8b4409..045ed7a2d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -336,15 +336,6 @@ var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - init := k.GlobalInit() - if init == nil { - // Attempted to read before the init Task is created. This can - // only occur during startup, which should never need to read - // this file. - panic("Attempted to read version before initial Task is available") - } - // /proc/version takes the form: // // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST) @@ -364,7 +355,7 @@ func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { // FIXME(mpratt): Using Version from the init task SyscallTable // disregards the different version a task may have (e.g., in a uts // namespace). - ver := init.Leader().SyscallTable().Version + ver := kernelVersion(ctx) fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) return nil } @@ -400,3 +391,31 @@ func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error { r.GenerateProcCgroups(buf) return nil } + +// cmdLineData backs /proc/cmdline. +// +// +stateify savable +type cmdLineData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cmdLineData)(nil) + +// Generate implements vfs.DynamicByteSource.Generate. +func (*cmdLineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", kernelVersion(ctx).Release) + return nil +} + +// kernelVersion returns the kernel version. +func kernelVersion(ctx context.Context) kernel.Version { + k := kernel.KernelFromContext(ctx) + init := k.GlobalInit() + if init == nil { + // Attempted to read before the init Task is created. This can + // only occur during startup, which should never need to read + // this file. + panic("Attempted to read version before initial Task is available") + } + return init.Leader().SyscallTable().Version +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 9b14dd6b9..88ab49048 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -365,27 +365,22 @@ func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error { } // ipForwarding implements vfs.WritableDynamicBytesSource for -// /proc/sys/net/ipv4/ip_forwarding. +// /proc/sys/net/ipv4/ip_forward. // // +stateify savable type ipForwarding struct { kernfs.DynamicBytesFile stack inet.Stack `state:"wait"` - enabled *bool + enabled bool } var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error { - if ipf.enabled == nil { - enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber) - ipf.enabled = &enabled - } - val := "0\n" - if *ipf.enabled { + if ipf.enabled { // Technically, this is not quite compatible with Linux. Linux stores these // as an integer, so if you write "2" into tcp_sack, you should get 2 back. // Tough luck. @@ -414,11 +409,8 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs if err != nil { return 0, err } - if ipf.enabled == nil { - ipf.enabled = new(bool) - } - *ipf.enabled = v != 0 - if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + ipf.enabled = v != 0 + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, ipf.enabled); err != nil { return 0, err } return n, nil diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index 6cee22823..19b012f7d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -132,7 +132,7 @@ func TestConfigureIPForwarding(t *testing.T) { t.Run(c.comment, func(t *testing.T) { s.IPForwarding = c.initial - file := &ipForwarding{stack: s, enabled: &c.initial} + file := &ipForwarding{stack: s, enabled: c.initial} // Write the values. src := usermem.BytesIOSequence([]byte(c.str)) diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index d6f076cd6..e534fbca8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -47,6 +47,7 @@ var ( var ( tasksStaticFiles = map[string]testutil.DirentType{ + "cmdline": linux.DT_REG, "cpuinfo": linux.DT_REG, "filesystems": linux.DT_REG, "loadavg": linux.DT_REG, diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 5fdca1d46..766289e60 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -465,7 +465,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open } return &fd.vfsfd, nil case *symlink: - // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH. + // Can't open symlinks without O_PATH, which is handled at the VFS layer. return nil, syserror.ELOOP case *namedPipe: return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index ca8090bbf..3582d14c9 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -168,10 +168,6 @@ afterSymlink: // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -// -// TODO(b/166474175): Investigate all possible errors returned in this -// function, and make sure we differentiate all errors that indicate unexpected -// modifications to the file system from the ones that are not harmful. func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() @@ -278,16 +274,15 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi var buf bytes.Buffer parent.hashMu.RLock() _, err = merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, - Children: parent.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + Children: parent.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), @@ -409,15 +404,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry var buf bytes.Buffer d.hashMu.RLock() params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - Children: d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + Children: d.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, // Set read size to 0 so only the metadata is verified. @@ -991,8 +985,6 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } // StatAt implements vfs.FilesystemImpl.StatAt. -// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should -// be verified. func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 458c7fcb6..fa7696ad6 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -840,7 +840,6 @@ func (fd *fileDescription) Release(ctx context.Context) { // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - // TODO(b/162788573): Add integrity check for metadata. stat, err := fd.lowerFD.Stat(ctx, opts) if err != nil { return linux.Statx{}, err @@ -960,10 +959,9 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui } params := &merkletree.GenerateParams{ - TreeReader: &merkleReader, - TreeWriter: &merkleWriter, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + TreeReader: &merkleReader, + TreeWriter: &merkleWriter, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), Name: fd.d.name, Mode: uint32(stat.Mode), @@ -1192,8 +1190,6 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. case linux.FS_IOC_GETFLAGS: return fd.verityFlags(ctx, args[2].Pointer()) default: - // TODO(b/169682228): Investigate which ioctl commands should - // be allowed. return 0, syserror.ENOSYS } } @@ -1253,16 +1249,15 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), @@ -1304,6 +1299,11 @@ func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapO return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) } +// SupportsLocks implements vfs.FileDescriptionImpl.SupportsLocks. +func (fd *fileDescription) SupportsLocks() bool { + return fd.lowerFD.SupportsLocks() +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) @@ -1333,7 +1333,7 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { ts, err := fd.lowerMappable.Translate(ctx, required, optional, at) if err != nil { - return ts, err + return nil, err } // dataSize is the size of the whole file. @@ -1346,17 +1346,17 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { - return ts, err + return nil, err } // The dataSize xattr should be an integer. If it's not, it indicates // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } merkleReader := FileReadWriteSeeker{ @@ -1389,7 +1389,7 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem DataAndTreeInSameFile: false, }) if err != nil { - return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } } return ts, err diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go index 7e535b527..17d0d5025 100644 --- a/pkg/sentry/fsmetric/fsmetric.go +++ b/pkg/sentry/fsmetric/fsmetric.go @@ -42,7 +42,6 @@ var ( // Metrics that only apply to fs/gofer and fsimpl/gofer. var ( - GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.") GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.") GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.") GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 6b71bd3a9..80dda1559 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -88,9 +88,6 @@ type Stack interface { // for restoring a stack after a save. RestoreCleanupEndpoints([]stack.TransportEndpoint) - // Forwarding returns if packet forwarding between NICs is enabled. - Forwarding(protocol tcpip.NetworkProtocolNumber) bool - // SetForwarding enables or disables packet forwarding between NICs. SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 03e2608c2..218d9dafc 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -154,11 +154,6 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// Forwarding implements inet.Stack.Forwarding. -func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - return s.IPForwarding -} - // SetForwarding implements inet.Stack.SetForwarding. func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { s.IPForwarding = enable diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go index 1f1c63f37..0fbf27f64 100644 --- a/pkg/sentry/kernel/cgroup.go +++ b/pkg/sentry/kernel/cgroup.go @@ -48,10 +48,6 @@ type CgroupController interface { // attached to. Returned value is valid for the lifetime of the controller. HierarchyID() uint32 - // Filesystem returns the filesystem this controller is attached to. - // Returned value is valid for the lifetime of the controller. - Filesystem() *vfs.Filesystem - // RootCgroup returns the root cgroup for this controller. Returned value is // valid for the lifetime of the controller. RootCgroup() Cgroup @@ -124,6 +120,19 @@ func (h *hierarchy) match(ctypes []CgroupControllerType) bool { return true } +// cgroupFS is the public interface to cgroupfs. This lets the kernel package +// refer to cgroupfs.filesystem methods without directly depending on the +// cgroupfs package, which would lead to a circular dependency. +type cgroupFS interface { + // Returns the vfs.Filesystem for the cgroupfs. + VFSFilesystem() *vfs.Filesystem + + // InitializeHierarchyID sets the hierarchy ID for this filesystem during + // filesystem creation. May only be called before the filesystem is visible + // to the vfs layer. + InitializeHierarchyID(hid uint32) +} + // CgroupRegistry tracks the active set of cgroup controllers on the system. // // +stateify savable @@ -182,31 +191,35 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files // Register registers the provided set of controllers with the registry as a new // hierarchy. If any controller is already registered, the function returns an -// error without modifying the registry. The hierarchy can be later referenced -// by the returned id. -func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { +// error without modifying the registry. Register sets the hierarchy ID for the +// filesystem on success. +func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error { r.mu.Lock() defer r.mu.Unlock() if len(cs) == 0 { - return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers") + return fmt.Errorf("can't register hierarchy with no controllers") } for _, c := range cs { if _, ok := r.controllers[c.Type()]; ok { - return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy") + return fmt.Errorf("controllers may only be mounted on a single hierarchy") } } hid, err := r.nextHierarchyID() if err != nil { - return hid, err + return err } + // Must not fail below here, once we publish the hierarchy ID. + + fs.InitializeHierarchyID(hid) + h := hierarchy{ id: hid, controllers: make(map[CgroupControllerType]CgroupController), - fs: cs[0].Filesystem(), + fs: fs.VFSFilesystem(), } for _, c := range cs { n := c.Type() @@ -214,7 +227,7 @@ func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { h.controllers[n] = c } r.hierarchies[hid] = h - return hid, nil + return nil } // Unregister removes a previously registered hierarchy from the registry. If @@ -253,6 +266,11 @@ func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[C for name, ctl := range r.controllers { if _, ok := ctlSet[name]; !ok { cg := ctl.RootCgroup() + // Multiple controllers may share the same hierarchy, so may have + // the same root cgroup. Grab a single ref per hierarchy root. + if _, ok := cgset[cg]; ok { + continue + } cg.IncRef() // Ref transferred to caller. cgset[cg] = struct{}{} } diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 10885688c..62777faa8 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -154,9 +154,11 @@ func (f *FDTable) drop(ctx context.Context, file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. - err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) - if err != nil && err != syserror.ENOLCK { - panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) + if file.SupportsLocks() { + err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) + if err != nil && err != syserror.ENOLCK { + panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) + } } // Drop the table's reference. diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 36855e3ec..601fc0d3a 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -30,8 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application") - // SyscallRestartBlock represents the restart block for a syscall restartable // with a custom function. It encapsulates the state required to restart a // syscall across a S/R. @@ -284,7 +282,6 @@ func (*runSyscallExit) execute(t *Task) taskRunState { // indicated by an execution fault at address addr. doVsyscall returns the // task's next run state. func (t *Task) doVsyscall(addr hostarch.Addr, sysno uintptr) taskRunState { - vsyscallCount.Increment() metric.WeirdnessMetric.Increment("vsyscall_count") // Grab the caller up front, to make sure there's a sensible stack. diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index ecb6603a1..4c65215fa 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -11,11 +11,12 @@ go_library( "vdso.go", "vdso_state.go", ], + marshal = True, + marshal_debug = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/cpuid", "//pkg/hostarch", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index e92d9fdc3..8fc3e2a79 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" @@ -47,10 +46,10 @@ const ( var ( // header64Size is the size of elf.Header64. - header64Size = int(binary.Size(elf.Header64{})) + header64Size = (*linux.ElfHeader64)(nil).SizeBytes() // Prog64Size is the size of elf.Prog64. - prog64Size = int(binary.Size(elf.Prog64{})) + prog64Size = (*linux.ElfProg64)(nil).SizeBytes() ) func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType { @@ -136,7 +135,6 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Unsupported ELF endianness: %v", endian) return elfInfo{}, syserror.ENOEXEC } - byteOrder := binary.LittleEndian if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT { log.Infof("Unsupported ELF version: %v", version) @@ -145,7 +143,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // EI_OSABI is ignored by Linux, which is the only OS supported. os := abi.Linux - var hdr elf.Header64 + var hdr linux.ElfHeader64 hdrBuf := make([]byte, header64Size) _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0) if err != nil { @@ -156,7 +154,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { } return elfInfo{}, err } - binary.Unmarshal(hdrBuf, byteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBuf) // We support amd64 and arm64. var a arch.Arch @@ -213,8 +211,8 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { phdrs := make([]elf.ProgHeader, hdr.Phnum) for i := range phdrs { - var prog64 elf.Prog64 - binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64) + var prog64 linux.ElfProg64 + prog64.UnmarshalUnsafe(phdrBuf[:prog64Size]) phdrBuf = phdrBuf[prog64Size:] phdrs[i] = elf.ProgHeader{ Type: elf.ProgType(prog64.Type), diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 99f036bba..1b5d5f66e 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -75,6 +75,9 @@ type machine struct { // nextID is the next vCPU ID. nextID uint32 + + // machineArchState is the architecture-specific state. + machineArchState } const ( @@ -196,12 +199,7 @@ func newMachine(vm int) (*machine, error) { m.available.L = &m.mu // Pull the maximum vCPUs. - maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) - if errno != 0 { - m.maxVCPUs = _KVM_NR_VCPUS - } else { - m.maxVCPUs = int(maxVCPUs) - } + m.getMaxVCPU() log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) @@ -427,9 +425,8 @@ func (m *machine) Get() *vCPU { } } - // Create a new vCPU (maybe). - if int(m.nextID) < m.maxVCPUs { - c := m.newVCPU() + // Get a new vCPU (maybe). + if c := m.getNewVCPU(); c != nil { c.lock() m.vCPUsByTID[tid] = c m.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index d7abfefb4..9a2337654 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -63,6 +63,9 @@ func (m *machine) initArchState() error { return nil } +type machineArchState struct { +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -351,6 +354,10 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) // allocations occur. entersyscall() bluepill(c) + // The root table physical page has to be mapped to not fault in iret + // or sysret after switching into a user address space. sysret and + // iret are in the upper half that is global and already mapped. + switchOpts.PageTables.PrefaultRootTable() prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -495,3 +502,22 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { physical) } } + +// getMaxVCPU get max vCPU number +func (m *machine) getMaxVCPU() { + maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + if errno != 0 { + m.maxVCPUs = _KVM_NR_VCPUS + } else { + m.maxVCPUs = int(maxVCPUs) + } +} + +// getNewVCPU create a new vCPU (maybe) +func (m *machine) getNewVCPU() *vCPU { + if int(m.nextID) < m.maxVCPUs { + c := m.newVCPU() + return c + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index cd912f922..8926b1d9f 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,6 +17,10 @@ package kvm import ( + "runtime" + "sync/atomic" + + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" @@ -25,6 +29,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" ) +type machineArchState struct { + //initialvCPUs is the machine vCPUs which has initialized but not used + initialvCPUs map[int]*vCPU +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -182,3 +191,30 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, return accessType, platform.ErrContextSignal } + +// getMaxVCPU get max vCPU number +func (m *machine) getMaxVCPU() { + rmaxVCPUs := runtime.NumCPU() + smaxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + // compare the max vcpu number from runtime and syscall, use smaller one. + if errno != 0 { + m.maxVCPUs = rmaxVCPUs + } else { + if rmaxVCPUs < int(smaxVCPUs) { + m.maxVCPUs = rmaxVCPUs + } else { + m.maxVCPUs = int(smaxVCPUs) + } + } +} + +// getNewVCPU() scan for an available vCPU from initialvCPUs +func (m *machine) getNewVCPU() *vCPU { + for CID, c := range m.initialvCPUs { + if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) { + delete(m.initialvCPUs, CID) + return c + } + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 634e55ec0..92edc992b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" + ktime "gvisor.dev/gvisor/pkg/sentry/time" ) type kvmVcpuInit struct { @@ -47,6 +48,19 @@ func (m *machine) initArchState() error { uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno)) } + + // Initialize all vCPUs on ARM64, while this does not happen on x86_64. + // The reason for the difference is that ARM64 and x86_64 have different KVM timer mechanisms. + // If we create vCPU dynamically on ARM64, the timer for vCPU would mess up for a short time. + // For more detail, please refer to https://github.com/google/gvisor/issues/5739 + m.initialvCPUs = make(map[int]*vCPU) + m.mu.Lock() + for int(m.nextID) < m.maxVCPUs-1 { + c := m.newVCPU() + c.state = 0 + m.initialvCPUs[c.id] = c + } + m.mu.Unlock() return nil } @@ -174,9 +188,58 @@ func (c *vCPU) setTSC(value uint64) error { return nil } +// getTSC gets the counter Physical Counter minus Virtual Offset. +func (c *vCPU) getTSC() error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + + if err := c.getOneRegister(®); err != nil { + return err + } + + return nil +} + // setSystemTime sets the vCPU to the system time. func (c *vCPU) setSystemTime() error { - return c.setSystemTimeLegacy() + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Use get the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + // replace getTSC to another setOneRegister syscall can get more accurate value? + start := uint64(ktime.Rdtsc()) + if err := c.getTSC(); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && (current <= upperThreshold || minimum < 50) { + // Try to set the TSC + if err := c.setTSC(end + (minimum / 2)); err != nil { + return err + } + return nil + } + } } //go:nosplit @@ -203,7 +266,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { uintptr(c.fd), _KVM_GET_ONE_REG, uintptr(unsafe.Pointer(reg))); errno != 0 { - return fmt.Errorf("error setting one register: %v", errno) + return fmt.Errorf("error getting one register: %v", errno) } return nil } diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 080859125..7ee89a735 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 0e0e82365..2029e7cf4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -14,9 +14,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 45a05cd63..235b9c306 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -18,9 +18,11 @@ package control import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -193,7 +195,7 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { - space := binary.AlignDown(cap(buf)-len(buf), 4) + space := bits.AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align // the available space, we must align down. @@ -230,7 +232,7 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ return alignSlice(buf, align), flags } -func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte { +func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte { if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { return buf } @@ -241,8 +243,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = putUint32(buf, msgType) hdrBuf := buf - - buf = binary.Marshal(buf, hostarch.ByteOrder, data) + buf = append(buf, marshal.Marshal(data)...) // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { @@ -288,7 +289,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int // alignSlice extends a slice's length (up to the capacity) to align it. func alignSlice(buf []byte, align uint) []byte { - aligned := binary.AlignUp(len(buf), align) + aligned := bits.AlignUp(len(buf), align) if aligned > cap(buf) { // Linux allows unaligned data if there isn't room for alignment. // Since there isn't room for alignment, there isn't room for any @@ -300,12 +301,13 @@ func alignSlice(buf []byte, align uint) []byte { // PackTimestamp packs a SO_TIMESTAMP socket control message. func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { + timestampP := linux.NsecToTimeval(timestamp) return putCmsgStruct( buf, linux.SOL_SOCKET, linux.SO_TIMESTAMP, t.Arch().Width(), - linux.NsecToTimeval(timestamp), + ×tampP, ) } @@ -316,7 +318,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { linux.SOL_TCP, linux.TCP_INQ, t.Arch().Width(), - inq, + primitive.AllocateInt32(inq), ) } @@ -327,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { linux.SOL_IP, linux.IP_TOS, t.Arch().Width(), - tos, + primitive.AllocateUint8(tos), ) } @@ -338,7 +340,7 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { linux.SOL_IPV6, linux.IPV6_TCLASS, t.Arch().Width(), - tClass, + primitive.AllocateUint32(tClass), ) } @@ -423,7 +425,7 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt // cmsgSpace is equivalent to CMSG_SPACE in Linux. func cmsgSpace(t *kernel.Task, dataLen int) int { - return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width()) + return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width()) } // CmsgsSpace returns the number of bytes needed to fit the control messages @@ -475,7 +477,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, syserror.EINVAL @@ -491,7 +493,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.SOL_SOCKET: switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) numRights := rightsSize / linux.SizeOfControlMessageRight if len(fds)+numRights > linux.SCM_MAX_FD { @@ -502,7 +504,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -510,23 +512,23 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, syserror.EINVAL } var ts linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts) + ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) cmsgs.IP.Timestamp = ts.ToNsecCapped() cmsgs.IP.HasTimestamp = true - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: // Unknown message type. @@ -539,8 +541,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTOS = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS) - i += binary.AlignUp(length, width) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + cmsgs.IP.TOS = uint8(tos) + i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -549,19 +553,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) cmsgs.IP.PacketInfo = packetInfo - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -571,7 +575,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL @@ -583,17 +587,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTClass = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass) - i += binary.AlignUp(length, width) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + cmsgs.IP.TClass = uint32(tclass) + i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -603,7 +609,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a5c2155a2..3c6511ead 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -17,7 +17,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fdnotifier", "//pkg/hostarch", @@ -40,8 +39,6 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 0d3b23643..52ae4bc9c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" @@ -529,7 +528,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} - ts.UnmarshalBytes(unixCmsg.Data[:linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) controlMessages.IP.Timestamp = ts.ToNsecCapped() } @@ -537,17 +536,19 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IP_TOS: controlMessages.IP.HasTOS = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -560,11 +561,13 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -577,7 +580,9 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.TCP_INQ: controlMessages.IP.HasInq = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq) + var inq primitive.Int32 + inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + controlMessages.IP.Inq = int32(inq) } } } @@ -691,7 +696,7 @@ func (s *socketOpsCommon) State() uint32 { return 0 } - binary.Unmarshal(buf, hostarch.ByteOrder, &info) + info.UnmarshalUnsafe(buf[:info.SizeBytes()]) return uint32(info.State) } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 26e8ae17a..cbb1e905d 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -15,6 +15,7 @@ package hostinet import ( + "encoding/binary" "fmt" "io" "io/ioutil" @@ -26,16 +27,14 @@ import ( "syscall" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -65,8 +64,6 @@ type Stack struct { tcpSACKEnabled bool netDevFile *os.File netSNMPFile *os.File - ipv4Forwarding bool - ipv6Forwarding bool } // NewStack returns an empty Stack containing no configuration. @@ -126,13 +123,6 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } - s.ipv6Forwarding = false - if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil { - s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" - } else { - log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false") - } - return nil } @@ -147,8 +137,8 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(link.Data) < unix.SizeofIfInfomsg { return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } - var ifinfo unix.IfInfomsg - binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo) + var ifinfo linux.InterfaceInfoMessage + ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -178,11 +168,11 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(addr.Data) < unix.SizeofIfAddrmsg { return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } - var ifaddr unix.IfAddrmsg - binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr) + var ifaddr linux.InterfaceAddrMessage + ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, - PrefixLen: ifaddr.Prefixlen, + PrefixLen: ifaddr.PrefixLen, Flags: ifaddr.Flags, } attrs, err := syscall.ParseNetlinkRouteAttr(&addr) @@ -210,13 +200,13 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) continue } - var ifRoute unix.RtMsg - binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute) + var ifRoute linux.RouteMessage + ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) inetRoute := inet.Route{ Family: ifRoute.Family, - DstLen: ifRoute.Dst_len, - SrcLen: ifRoute.Src_len, - TOS: ifRoute.Tos, + DstLen: ifRoute.DstLen, + SrcLen: ifRoute.SrcLen, + TOS: ifRoute.TOS, Table: ifRoute.Table, Protocol: ifRoute.Protocol, Scope: ifRoute.Scope, @@ -245,7 +235,9 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) if len(attr.Value) != expected { return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected) } - binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface) + var outputIF primitive.Int32 + outputIF.UnmarshalUnsafe(attr.Value) + inetRoute.OutputInterface = int32(outputIF) } } @@ -489,19 +481,6 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// Forwarding implements inet.Stack.Forwarding. -func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - switch protocol { - case ipv4.ProtocolNumber: - return s.ipv4Forwarding - case ipv6.ProtocolNumber: - return s.ipv6Forwarding - default: - log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) - return false - } -} - // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 4381dfa06..61b2c9755 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -14,14 +14,16 @@ go_library( "tcp_matcher.go", "udp_matcher.go", ], + marshal = True, # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/hostarch", "//pkg/log", + "//pkg/marshal", "//pkg/sentry/kernel", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 4bd305a44..6fc7781ad 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -79,7 +78,7 @@ func marshalEntryMatch(name string, data []byte) []byte { nflog("marshaling matcher %q", name) // We have to pad this struct size to a multiple of 8 bytes. - size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) + size := bits.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) matcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ MatchSize: uint16(size), @@ -88,9 +87,11 @@ func marshalEntryMatch(name string, data []byte) []byte { } copy(matcher.Name[:], name) - buf := make([]byte, 0, size) - buf = binary.Marshal(buf, hostarch.ByteOrder, matcher) - return append(buf, make([]byte, size-len(buf))...) + buf := make([]byte, size) + entryLen := matcher.XTEntryMatch.SizeBytes() + matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) + copy(buf[entryLen:], matcher.Data) + return buf } func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 1fc4cb651..cb78ef60b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -141,10 +139,9 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IPTEntry - buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIPTEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 67a52b628..5cb7fe4aa 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -144,10 +142,9 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IP6TEntry - buf := optVal[:linux.SizeOfIP6TEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIP6TEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index c6fa3fd16..f42d73178 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -22,7 +22,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -121,7 +120,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } @@ -146,7 +145,7 @@ func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } @@ -179,7 +178,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] - binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace) + replace.UnmarshalBytes(replaceBuf) // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table @@ -309,8 +308,8 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:linux.SizeOfXTEntryMatch] - binary.Unmarshal(buf, hostarch.ByteOrder, &match) + buf := optVal[:match.SizeBytes()] + match.UnmarshalUnsafe(buf) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index b2cc6be20..60845cab3 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,8 +58,8 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } } - buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) - return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo)) + buf := marshal.Marshal(&iptOwnerInfo) + return marshalEntryMatch(matcherNameOwner, buf) } // unmarshal implements matchMaker.unmarshal. @@ -72,7 +71,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 4ae1592b2..fa5456eee 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,11 +15,12 @@ package netfilter import ( + "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -189,8 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte { Verdict: verdict, } - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -199,8 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - buf = buf[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget) + standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -245,8 +244,7 @@ func (*errorTargetMaker) marshal(target target) []byte { copy(xt.Name[:], errorName) copy(xt.Target.Name[:], ErrorTargetName) - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -256,7 +254,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt) + errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: // * An actual error case. These rules have an error named @@ -299,12 +297,11 @@ func (*redirectTargetMaker) marshal(target target) []byte { } copy(xt.Target.Name[:], RedirectTargetName) - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -320,7 +317,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &rt) + rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. target := redirectTarget{RedirectTarget: stack.RedirectTarget{ @@ -359,6 +356,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return &target, nil } +// +marshal type nfNATTarget struct { Target linux.XTEntryTarget Range linux.NFNATRange @@ -394,8 +392,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte { nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarshalledSize) - return binary.Marshal(ret, hostarch.ByteOrder, nt) + return marshal.Marshal(&nt) } func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -411,7 +408,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar var natRange linux.NFNATRange buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + natRange.UnmarshalUnsafe(buf) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -468,8 +465,7 @@ func (*snatTargetMakerV4) marshal(target target) []byte { xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr) copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr) - ret := make([]byte, 0, linux.SizeOfXTSNATTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -485,7 +481,7 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta var st linux.XTSNATTarget buf = buf[:linux.SizeOfXTSNATTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &st) + st.UnmarshalUnsafe(buf) // Copy linux.XTSNATTarget to stack.SNATTarget. target := snatTarget{SNATTarget: stack.SNATTarget{ @@ -550,8 +546,7 @@ func (*snatTargetMakerV6) marshal(target target) []byte { nt.Range.MinProto = htons(st.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarshalledSize) - return binary.Marshal(ret, hostarch.ByteOrder, nt) + return marshal.Marshal(&nt) } func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -567,9 +562,9 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta var natRange linux.NFNATRange buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + natRange.UnmarshalUnsafe(buf) - // TODO(gvisor.dev/issue/5689): Support port or address ranges. + // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { nflog("snatTargetMakerV6: MinAddr and MaxAddr are different") return nil, syserr.ErrInvalidArgument @@ -631,8 +626,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &target) + target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) return unmarshalTarget(target, filter, optVal) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 69557f515..95bb9826e 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTTCP) - return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp)) + return marshalEntryMatch(matcherNameTCP, marshal.Marshal(&xttcp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.XTTCP - binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 6a60e6bd6..fb8be27e6 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTUDP) - return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp)) + return marshalEntryMatch(matcherNameUDP, marshal.Marshal(&xtudp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 171b95c63..64cd263da 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -14,7 +14,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", @@ -50,5 +50,7 @@ go_test( deps = [ ":netlink", "//pkg/abi/linux", + "//pkg/marshal", + "//pkg/marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index ab0e68af7..80385bfdc 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -19,15 +19,17 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" ) // alignPad returns the length of padding required for alignment. // // Preconditions: align is a power of two. func alignPad(length int, align uint) int { - return binary.AlignUp(length, align) - length + return bits.AlignUp(length, align) - length } // Message contains a complete serialized netlink message. @@ -42,7 +44,7 @@ type Message struct { func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ hdr: hdr, - buf: binary.Marshal(nil, hostarch.ByteOrder, hdr), + buf: marshal.Marshal(&hdr), } } @@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { return } var hdr linux.NetlinkMessageHeader - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) // Msg portion. totalMsgLen := int(hdr.Length) @@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader { // GetData unmarshals the payload message header from this netlink message, and // returns the attributes portion. -func (m *Message) GetData(msg interface{}) (AttrsView, bool) { +func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) { b := BytesView(m.buf) _, ok := b.Extract(linux.NetlinkMessageHeaderSize) @@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) { return nil, false } - size := int(binary.Size(msg)) + size := msg.SizeBytes() msgBytes, ok := b.Extract(size) if !ok { return nil, false } - binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg) + msg.UnmarshalUnsafe(msgBytes) numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) // Linux permits the last message not being aligned, just consume all of it. @@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte { // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned // length. See net/netlink/af_netlink.c:__nlmsg_put. - aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) + aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) m.putZeros(aligned - len(m.buf)) return m.buf } @@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) { } // Put serializes v into the message. -func (m *Message) Put(v interface{}) { - m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v) +func (m *Message) Put(v marshal.Marshallable) { + m.buf = append(m.buf, marshal.Marshal(v)...) } // PutAttr adds v to the message as a netlink attribute. // // Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize + -// binary.Size(v) fits in math.MaxUint16 bytes. -func (m *Message) PutAttr(atype uint16, v interface{}) { - l := linux.NetlinkAttrHeaderSize + int(binary.Size(v)) +// v.SizeBytes()) fits in math.MaxUint16 bytes. +func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) { + l := linux.NetlinkAttrHeaderSize + v.SizeBytes() if l > math.MaxUint16 { panic(fmt.Sprintf("attribute too large: %d", l)) } - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) m.Put(v) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } // PutAttrString adds s to the message as a netlink attribute. func (m *Message) PutAttrString(atype uint16, s string) { l := linux.NetlinkAttrHeaderSize + len(s) + 1 - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) // String + NUL-termination. - m.Put([]byte(s)) + m.Put(primitive.AsByteSlice([]byte(s))) m.putZeros(1) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest if !ok { return } - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) if !ok { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index ef13d9386..968968469 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -20,13 +20,31 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) type dummyNetlinkMsg struct { + marshal.StubMarshallable Foo uint16 } +func (*dummyNetlinkMsg) SizeBytes() int { + return 2 +} + +func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { + p := primitive.Uint16(m.Foo) + p.MarshalUnsafe(dst) +} + +func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { + var p primitive.Uint16 + p.UnmarshalUnsafe(src) + m.Foo = uint16(p) +} + func TestParseMessage(t *testing.T) { tests := []struct { desc string diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 744fc74f4..c6c04b4e3 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 5a2255db3..86f6419dc 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { Type: linux.RTM_NEWLINK, }) - m.Put(linux.InterfaceInfoMessage{ + m.Put(&linux.InterfaceInfoMessage{ Family: linux.AF_UNSPEC, Type: i.DeviceType, Index: idx, @@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { }) m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) + m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU)) mac := make([]byte, 6) brd := mac @@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { mac = i.Addr brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac)) + m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd)) // TODO(gvisor.dev/issue/578): There are many more attributes. } @@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl Type: linux.RTM_NEWADDR, }) - m.Put(linux.InterfaceAddrMessage{ + m.Put(&linux.InterfaceAddrMessage{ Family: a.Family, PrefixLen: a.PrefixLen, Index: uint32(id), }) - m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) - m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) + addr := primitive.ByteSlice([]byte(a.Addr)) + m.PutAttr(linux.IFA_LOCAL, &addr) + m.PutAttr(linux.IFA_ADDRESS, &addr) // TODO(gvisor.dev/issue/578): There are many more attributes. } @@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Type: linux.RTM_NEWROUTE, }) - m.Put(linux.RouteMessage{ + m.Put(&linux.RouteMessage{ Family: rt.Family, DstLen: rt.DstLen, SrcLen: rt.SrcLen, @@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Flags: rt.Flags, }) - m.PutAttr(254, []byte{123}) + m.PutAttr(254, primitive.AsByteSlice([]byte{123})) if rt.DstLen > 0 { - m.PutAttr(linux.RTA_DST, rt.DstAddr) + m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr)) } if rt.SrcLen > 0 { - m.PutAttr(linux.RTA_SRC, rt.SrcAddr) + m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr)) } if rt.OutputInterface != 0 { - m.PutAttr(linux.RTA_OIF, rt.OutputInterface) + m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface)) } if len(rt.GatewayAddr) > 0 { - m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr) + m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr)) } // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms hdr := msg.Header() // All messages start with a 1 byte protocol family. - var family uint8 + var family primitive.Uint8 if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 30c297149..280563d09 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,7 +20,6 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa) + sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument @@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - sendBufferSizeP := primitive.Int32(s.sendBufferSize) - return &sendBufferSizeP, nil + return primitive.AllocateInt32(int32(s.sendBufferSize)), nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - recvBufferSizeP := primitive.Int32(math.MaxInt32) - return &recvBufferSizeP, nil + return primitive.AllocateInt32(math.MaxInt32), nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * Family: linux.AF_NETLINK, PortID: uint32(s.portID), } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // GetPeerName implements socket.Socket.GetPeerName. @@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * // must be the kernel. PortID: 0, } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // RecvMsg implements socket.Socket.RecvMsg. @@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Family: linux.AF_NETLINK, PortID: 0, } - fromLen := uint32(binary.Size(from)) + fromLen := uint32(from.SizeBytes()) trunc := flags&linux.MSG_TRUNC != 0 @@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys }) // Add the dump_done_errno payload. - m.Put(int64(0)) + m.Put(primitive.AllocateInt64(0)) _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { @@ -658,8 +655,8 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ - Error: int32(-err.ToLinux().Number()), + m.Put(&linux.NetlinkErrorMessage{ + Error: int32(-err.ToLinux()), Header: hdr, }) } @@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ + m.Put(&linux.NetlinkErrorMessage{ Error: 0, Header: hdr, }) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 0b39a5b67..9561b7c25 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -19,7 +19,6 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 312f5f85a..0b64a24c3 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "encoding/binary" "fmt" "io" "io/ioutil" @@ -35,7 +36,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -199,6 +199,15 @@ var Metrics = tcpip.Stats{ OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."), OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."), OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."), + Forwarding: tcpip.IPForwardingStats{ + Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."), + ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."), + LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), + LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), + ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."), + PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not fit within the outgoing MTU."), + Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), + }, }, ARP: tcpip.ARPStats{ PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), @@ -375,9 +384,9 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue }), nil } -var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) -var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) -var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) +var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() +var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() +var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes() // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -613,7 +622,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) if a.Protocol != uint16(s.protocol) { return syserr.ErrInvalidArgument @@ -843,7 +852,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &optP, nil } - optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux()) return &optP, nil case linux.SO_PEERCRED: @@ -1312,7 +1321,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &v, nil case linux.IP6T_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet6{})) { + if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument } @@ -1509,7 +1518,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return &v, nil case linux.SO_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet{})) { + if outLen < sockAddrInetSize { return nil, syserr.ErrInvalidArgument } @@ -1742,7 +1751,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1755,7 +1764,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1791,7 +1800,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + + if v != (linux.Linger{}) { + socket.SetSockOptEmitUnimplementedEvent(t, name) + } ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, @@ -2090,9 +2103,9 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } var ( - inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{})) - inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{})) - inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{})) + inetMulticastRequestSize = (*linux.InetMulticastRequest)(nil).SizeBytes() + inetMulticastRequestWithNICSize = (*linux.InetMulticastRequestWithNIC)(nil).SizeBytes() + inet6MulticastRequestSize = (*linux.Inet6MulticastRequest)(nil).SizeBytes() ) // copyInMulticastRequest copies in a variable-size multicast request. The @@ -2117,12 +2130,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) return req, nil } var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest) + req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) return req, nil } @@ -2132,7 +2145,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) return req, nil } @@ -3101,8 +3114,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe continue } // Populate ifr.ifr_netmask (type sockaddr). - hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET)) - hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0) + hostarch.ByteOrder.PutUint16(ifr.Data[0:], uint16(linux.AF_INET)) + hostarch.ByteOrder.PutUint16(ifr.Data[2:], 0) var mask uint32 = 0xffffffff << (32 - addr.PrefixLen) // Netmask is expected to be returned as a big endian // value. diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index b215067cf..eef5e6519 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -458,23 +458,10 @@ func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { s.Stack.RestoreCleanupEndpoints(es) } -// Forwarding implements inet.Stack.Forwarding. -func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - return s.Stack.Forwarding(protocol) - default: - panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol)) - } -} - // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - s.Stack.SetForwarding(protocol, enable) - default: - panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + if err := s.Stack.SetForwardingDefaultAndAllNICs(protocol, enable); err != nil { + return fmt.Errorf("SetForwardingDefaultAndAllNICs(%d, %t): %s", protocol, enable, err) } return nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 4c3d48096..353f4ade0 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -24,7 +24,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -81,7 +80,7 @@ func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { } ee := linux.SockExtendedErr{ - Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux()), Origin: errOriginToLinux(sockErr.Cause.Origin()), Type: sockErr.Cause.Type(), Code: sockErr.Cause.Code(), @@ -572,19 +571,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -716,7 +715,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInetSize]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -729,7 +728,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -745,7 +744,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 2ebd77f82..1fbbd133c 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -25,7 +25,6 @@ go_library( ":strace_go_proto", "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", "//pkg/hostarch", diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index 71b92eaee..d66befe81 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -371,6 +371,7 @@ var linuxAMD64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index bd7361a52..1a2d7d75f 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -312,6 +312,7 @@ var linuxARM64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index e5b7f9b96..f4aab25b0 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -20,14 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - - "gvisor.dev/gvisor/pkg/hostarch" ) // SocketFamily are the possible socket(2) families. @@ -162,6 +161,15 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } +func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { + count := len(src) / linux.SizeOfControlMessageRight + cmr := make(linux.ControlMessageRights, count) + for i, _ := range cmr { + cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) + } + return cmr +} + func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string { if length > maxBytes { return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length) @@ -181,7 +189,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) var skipData bool level := "SOL_SOCKET" @@ -221,18 +229,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) continue } switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) - - numRights := rightsSize / linux.SizeOfControlMessageRight - fds := make(linux.ControlMessageRights, numRights) - binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds) - + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) rights := make([]string, 0, len(fds)) for _, fd := range fds { rights = append(rights, fmt.Sprint(fd)) @@ -258,7 +262,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", @@ -282,7 +286,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var tv linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv) + tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", @@ -296,7 +300,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) default: panic("unreachable") } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index e115683f8..3b4d79889 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -119,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { } // WaitEpoll implements the epoll_wait(2) linux syscall. -func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) { +func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux.EpollEvent, error) { // Get epoll from the file descriptor. epollfile := t.GetFile(fd) if epollfile == nil { @@ -136,7 +136,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // Try to read events and return right away if we got them or if the // caller requested a non-blocking "wait". r := e.ReadEvents(max) - if len(r) != 0 || timeout == 0 { + if len(r) != 0 || timeoutInNanos == 0 { return r, nil } @@ -144,8 +144,8 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // and register with the epoll object for readability events. var haveDeadline bool var deadline ktime.Time - if timeout > 0 { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index c668e81ac..6eabfd219 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -29,8 +29,7 @@ import ( ) var ( - partialResultMetric = metric.MustCreateNewUint64Metric("/syscalls/partial_result", true /* sync */, "Whether or not a partial result has occurred for this sandbox.") - partialResultOnce sync.Once + partialResultOnce sync.Once ) // incrementPartialResultMetric increments PartialResultMetric by calling @@ -38,7 +37,6 @@ var ( // us to pass a function which does not take any arguments, whereas Increment() // takes a variadic number of arguments. func incrementPartialResultMetric() { - partialResultMetric.Increment() metric.WeirdnessMetric.Increment("partial_result") } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 2d2212605..090c5ffcb 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -404,6 +404,7 @@ var AMD64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{ 0xffffffffff600000: 96, // vsyscall gettimeofday(2) @@ -722,6 +723,7 @@ var ARM64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{}, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 7f460d30b..69cbc98d0 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" @@ -104,14 +105,8 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements the epoll_wait(2) linux syscall. -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - - r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout) +func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { + r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos) if err != nil { return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) } @@ -123,6 +118,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return uintptr(len(r)), nil, nil + +} + +// EpollWait implements the epoll_wait(2) linux syscall. +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + // Convert milliseconds to nanoseconds. + timeoutInNanos := int64(args[3].Int()) * 1000000 + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements the epoll_pwait(2) linux syscall. @@ -144,4 +150,38 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } +// EpollPwait2 implements the epoll_pwait(2) linux syscall. +func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + timeout, err := copyTimespecIn(t, timeoutPtr) + if err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + + } + + if maskAddr != 0 { + mask, err := CopyInSigSet(t, maskAddr, maskSize) + if err != nil { + return 0, nil, err + } + + oldmask := t.SignalMask() + t.SetSignalMask(mask) + t.SetSavedSignalMask(oldmask) + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} + // LINT.ThenChange(vfs2/epoll.go) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 5e9e940df..e07917613 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -463,8 +463,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index b980aa43e..047d955b6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -118,13 +119,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements Linux syscall epoll_wait(2). -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - +func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { return 0, nil, syserror.EINVAL @@ -158,7 +153,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return 0, nil, err } - if timeout == 0 { + if timeoutInNanos == 0 { return 0, nil, nil } // In the first iteration of this loop, register with the epoll @@ -173,8 +168,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer epfile.EventUnregister(&w) } else { // Set up the timer if a timeout was specified. - if timeout > 0 && !haveDeadline { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 && !haveDeadline { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } @@ -186,6 +181,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } } } + +} + +// EpollWait implements Linux syscall epoll_wait(2). +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutInNanos := int64(args[3].Int()) * 1000000 + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements Linux syscall epoll_pwait(2). @@ -199,3 +205,29 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } + +// EpollPwait2 implements Linux syscall epoll_pwait(2). +func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + var timeout linux.Timespec + if _, err := timeout.CopyIn(t, timeoutPtr); err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + } + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 6edde0ed1..69f69e3af 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -467,8 +467,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index c50fd97eb..0fc81e694 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -159,6 +159,7 @@ func Override() { s.Table[327] = syscalls.Supported("preadv2", Preadv2) s.Table[328] = syscalls.Supported("pwritev2", Pwritev2) s.Table[332] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() // Override ARM64. @@ -269,6 +270,7 @@ func Override() { s.Table[286] = syscalls.Supported("preadv2", Preadv2) s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) s.Table[291] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() } diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index 94f98d746..39bf1e0de 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -25,11 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// fallbackMetric tracks failed updates. It is not sync, as it is not critical -// that all occurrences are captured and CalibratedClock may fallback many -// times. -var fallbackMetric = metric.MustCreateNewUint64Metric("/time/fallback", false /* sync */, "Incremented when a clock falls back to system calls due to a failed update") - // CalibratedClock implements a clock that tracks a reference clock. // // Users should call Update at regular intervals of around approxUpdateInterval @@ -102,8 +97,7 @@ func (c *CalibratedClock) resetLocked(str string, v ...interface{}) { c.Warningf(str+" Resetting clock; time may jump.", v...) c.ready = false c.ref.Reset() - fallbackMetric.Increment() - metric.WeirdnessMetric.Increment("fallback") + metric.WeirdnessMetric.Increment("time_fallback") } // updateParams updates the timekeeping parameters based on the passed diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index f612a71b2..ef8d8a813 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -454,6 +454,9 @@ type FileDescriptionImpl interface { // RemoveXattr removes the given extended attribute from the file. RemoveXattr(ctx context.Context, name string) error + // SupportsLocks indicates whether file locks are supported. + SupportsLocks() bool + // LockBSD tries to acquire a BSD-style advisory file lock. LockBSD(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, block lock.Blocker) error @@ -524,7 +527,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St Start: fd.vd, }) stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } return fd.impl.Stat(ctx, opts) @@ -539,7 +542,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetStat(ctx, opts) @@ -555,7 +558,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vd, }) statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } return fd.impl.StatFS(ctx) @@ -701,7 +704,7 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string Start: fd.vd, }) names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, err } names, err := fd.impl.ListXattr(ctx, size) @@ -730,7 +733,7 @@ func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) Start: fd.vd, }) val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, err } return fd.impl.GetXattr(ctx, *opts) @@ -746,7 +749,7 @@ func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetXattr(ctx, *opts) @@ -762,7 +765,7 @@ func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { Start: fd.vd, }) err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.RemoveXattr(ctx, name) @@ -818,6 +821,11 @@ func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) e return fd.Sync(ctx) } +// SupportsLocks indicates whether file locks are supported. +func (fd *FileDescription) SupportsLocks() bool { + return fd.impl.SupportsLocks() +} + // LockBSD tries to acquire a BSD-style advisory file lock. func (fd *FileDescription) LockBSD(ctx context.Context, ownerPID int32, lockType lock.LockType, blocker lock.Blocker) error { atomic.StoreUint32(&fd.usedLockBSD, 1) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index b87d9690a..2b6f47b4b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -413,6 +413,11 @@ type LockFD struct { locks *FileLocks } +// SupportsLocks implements FileDescriptionImpl.SupportsLocks. +func (LockFD) SupportsLocks() bool { + return true +} + // Init initializes fd with FileLocks to use. func (fd *LockFD) Init(locks *FileLocks) { fd.locks = locks @@ -423,28 +428,28 @@ func (fd *LockFD) Locks() *FileLocks { return fd.locks } -// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +// LockBSD implements FileDescriptionImpl.LockBSD. func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.locks.LockBSD(ctx, uid, ownerPID, t, block) } -// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +// UnlockBSD implements FileDescriptionImpl.UnlockBSD. func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { fd.locks.UnlockBSD(uid) return nil } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +// LockPOSIX implements FileDescriptionImpl.LockPOSIX. func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { return fd.locks.LockPOSIX(ctx, uid, ownerPID, t, r, block) } -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { return fd.locks.UnlockPOSIX(ctx, uid, r) } -// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +// TestPOSIX implements FileDescriptionImpl.TestPOSIX. func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { return fd.locks.TestPOSIX(ctx, uid, t, r) } @@ -455,27 +460,68 @@ func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.L // +stateify savable type NoLockFD struct{} -// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +// SupportsLocks implements FileDescriptionImpl.SupportsLocks. +func (NoLockFD) SupportsLocks() bool { + return false +} + +// LockBSD implements FileDescriptionImpl.LockBSD. func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return syserror.ENOLCK } -// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +// UnlockBSD implements FileDescriptionImpl.UnlockBSD. func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { return syserror.ENOLCK } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +// LockPOSIX implements FileDescriptionImpl.LockPOSIX. func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { return syserror.ENOLCK } -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { return syserror.ENOLCK } -// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +// TestPOSIX implements FileDescriptionImpl.TestPOSIX. func (NoLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { return linux.Flock{}, syserror.ENOLCK } + +// BadLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface +// returning EBADF. +// +// +stateify savable +type BadLockFD struct{} + +// SupportsLocks implements FileDescriptionImpl.SupportsLocks. +func (BadLockFD) SupportsLocks() bool { + return false +} + +// LockBSD implements FileDescriptionImpl.LockBSD. +func (BadLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { + return syserror.EBADF +} + +// UnlockBSD implements FileDescriptionImpl.UnlockBSD. +func (BadLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { + return syserror.EBADF +} + +// LockPOSIX implements FileDescriptionImpl.LockPOSIX. +func (BadLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + return syserror.EBADF +} + +// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. +func (BadLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return syserror.EBADF +} + +// TestPOSIX implements FileDescriptionImpl.TestPOSIX. +func (BadLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return linux.Flock{}, syserror.EBADF +} diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go index 39fbac987..e9651b631 100644 --- a/pkg/sentry/vfs/opath.go +++ b/pkg/sentry/vfs/opath.go @@ -24,96 +24,96 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// opathFD implements vfs.FileDescriptionImpl for a file description opened with O_PATH. +// opathFD implements FileDescriptionImpl for a file description opened with O_PATH. // // +stateify savable type opathFD struct { vfsfd FileDescription FileDescriptionDefaultImpl - NoLockFD + BadLockFD } -// Release implements vfs.FileDescriptionImpl.Release. +// Release implements FileDescriptionImpl.Release. func (fd *opathFD) Release(context.Context) { // noop } -// Allocate implements vfs.FileDescriptionImpl.Allocate. +// Allocate implements FileDescriptionImpl.Allocate. func (fd *opathFD) Allocate(ctx context.Context, mode, offset, length uint64) error { return syserror.EBADF } -// PRead implements vfs.FileDescriptionImpl.PRead. +// PRead implements FileDescriptionImpl.PRead. func (fd *opathFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.EBADF } -// Read implements vfs.FileDescriptionImpl.Read. +// Read implements FileDescriptionImpl.Read. func (fd *opathFD) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { return 0, syserror.EBADF } -// PWrite implements vfs.FileDescriptionImpl.PWrite. +// PWrite implements FileDescriptionImpl.PWrite. func (fd *opathFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { return 0, syserror.EBADF } -// Write implements vfs.FileDescriptionImpl.Write. +// Write implements FileDescriptionImpl.Write. func (fd *opathFD) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { return 0, syserror.EBADF } -// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +// Ioctl implements FileDescriptionImpl.Ioctl. func (fd *opathFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { return 0, syserror.EBADF } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +// IterDirents implements FileDescriptionImpl.IterDirents. func (fd *opathFD) IterDirents(ctx context.Context, cb IterDirentsCallback) error { return syserror.EBADF } -// Seek implements vfs.FileDescriptionImpl.Seek. +// Seek implements FileDescriptionImpl.Seek. func (fd *opathFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { return 0, syserror.EBADF } -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap. func (fd *opathFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.EBADF } -// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +// ListXattr implements FileDescriptionImpl.ListXattr. func (fd *opathFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { return nil, syserror.EBADF } -// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +// GetXattr implements FileDescriptionImpl.GetXattr. func (fd *opathFD) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { return "", syserror.EBADF } -// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +// SetXattr implements FileDescriptionImpl.SetXattr. func (fd *opathFD) SetXattr(ctx context.Context, opts SetXattrOptions) error { return syserror.EBADF } -// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +// RemoveXattr implements FileDescriptionImpl.RemoveXattr. func (fd *opathFD) RemoveXattr(ctx context.Context, name string) error { return syserror.EBADF } -// Sync implements vfs.FileDescriptionImpl.Sync. +// Sync implements FileDescriptionImpl.Sync. func (fd *opathFD) Sync(ctx context.Context) error { return syserror.EBADF } -// SetStat implements vfs.FileDescriptionImpl.SetStat. +// SetStat implements FileDescriptionImpl.SetStat. func (fd *opathFD) SetStat(ctx context.Context, opts SetStatOptions) error { return syserror.EBADF } -// Stat implements vfs.FileDescriptionImpl.Stat. +// Stat implements FileDescriptionImpl.Stat. func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { vfsObj := fd.vfsfd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ @@ -121,7 +121,7 @@ func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, err Start: fd.vfsfd.vd, }) stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } @@ -134,6 +134,6 @@ func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vfsfd.vd, }) statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index e4fd55012..97b898aba 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -44,13 +44,10 @@ type ResolvingPath struct { start *Dentry pit fspath.Iterator - flags uint16 - mustBeDir bool // final file must be a directory? - mustBeDirOrig bool - symlinks uint8 // number of symlinks traversed - symlinksOrig uint8 - curPart uint8 // index into parts - numOrigParts uint8 + flags uint16 + mustBeDir bool // final file must be a directory? + symlinks uint8 // number of symlinks traversed + curPart uint8 // index into parts creds *auth.Credentials @@ -60,14 +57,9 @@ type ResolvingPath struct { nextStart *Dentry // ref held if not nil absSymlinkTarget fspath.Path - // ResolvingPath must track up to two relative paths: the "current" - // relative path, which is updated whenever a relative symlink is - // encountered, and the "original" relative path, which is updated from the - // current relative path by handleError() when resolution must change - // filesystems (due to reaching a mount boundary or absolute symlink) and - // overwrites the current relative path when Restart() is called. - parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator - origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator + // ResolvingPath tracks relative paths, which is updated whenever a relative + // symlink is encountered. + parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator } const ( @@ -120,6 +112,8 @@ var resolvingPathPool = sync.Pool{ }, } +// getResolvingPath gets a new ResolvingPath from the pool. Caller must call +// ResolvingPath.Release() when done. func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath { rp := resolvingPathPool.Get().(*ResolvingPath) rp.vfs = vfs @@ -132,17 +126,37 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat rp.flags |= rpflagsFollowFinalSymlink } rp.mustBeDir = pop.Path.Dir - rp.mustBeDirOrig = pop.Path.Dir rp.symlinks = 0 rp.curPart = 0 - rp.numOrigParts = 1 rp.creds = creds rp.parts[0] = pop.Path.Begin - rp.origParts[0] = pop.Path.Begin return rp } -func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) { +// Copy creates another ResolvingPath with the same state as the original. +// Copies are independent, using the copy does not change the original and +// vice-versa. +// +// Caller must call Resease() when done. +func (rp *ResolvingPath) Copy() *ResolvingPath { + copy := resolvingPathPool.Get().(*ResolvingPath) + *copy = *rp // All fields all shallow copiable. + + // Take extra reference for the copy if the original had them. + if copy.flags&rpflagsHaveStartRef != 0 { + copy.start.IncRef() + } + if copy.flags&rpflagsHaveMountRef != 0 { + copy.mount.IncRef() + } + // Reset error state. + copy.nextStart = nil + copy.nextMount = nil + return copy +} + +// Release decrements references if needed and returns the object to the pool. +func (rp *ResolvingPath) Release(ctx context.Context) { rp.root = VirtualDentry{} rp.decRefStartAndMount(ctx) rp.mount = nil @@ -240,25 +254,6 @@ func (rp *ResolvingPath) Advance() { } } -// Restart resets the stream of path components represented by rp to its state -// on entry to the current FilesystemImpl method. -func (rp *ResolvingPath) Restart(ctx context.Context) { - rp.pit = rp.origParts[rp.numOrigParts-1] - rp.mustBeDir = rp.mustBeDirOrig - rp.symlinks = rp.symlinksOrig - rp.curPart = rp.numOrigParts - 1 - copy(rp.parts[:], rp.origParts[:rp.numOrigParts]) - rp.releaseErrorState(ctx) -} - -func (rp *ResolvingPath) relpathCommit() { - rp.mustBeDirOrig = rp.mustBeDir - rp.symlinksOrig = rp.symlinks - rp.numOrigParts = rp.curPart + 1 - copy(rp.origParts[:rp.curPart], rp.parts[:]) - rp.origParts[rp.curPart] = rp.pit -} - // CheckRoot is called before resolving the parent of the Dentry d. If the // Dentry is contextually a VFS root, such that path resolution should treat // d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the @@ -405,11 +400,10 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef rp.nextMount = nil rp.nextStart = nil - // Commit the previous FileystemImpl's progress through the relative - // path. (Don't consume the path component that caused us to traverse + // Don't consume the path component that caused us to traverse // through the mount root - i.e. the ".." - because we still need to - // resolve the mount point's parent in the new FilesystemImpl.) - rp.relpathCommit() + // resolve the mount point's parent in the new FilesystemImpl. + // // Restart path resolution on the new Mount. Don't bother calling // rp.releaseErrorState() since we already set nextMount and nextStart // to nil above. @@ -425,9 +419,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.nextMount = nil // Consume the path component that represented the mount point. rp.Advance() - // Commit the previous FilesystemImpl's progress through the relative - // path. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true @@ -442,9 +433,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.Advance() // Prepend the symlink target to the relative path. rp.relpathPrepend(rp.absSymlinkTarget) - // Commit the previous FilesystemImpl's progress through the relative - // path, including the symlink target we just prepended. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 00f1847d8..87fdcf403 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede dentry: d, } rp.mount.IncRef() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return vd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, err } } @@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } rp.mount.IncRef() name := rp.Component() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return parentVD, name, nil } if checkInvariants { @@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, "", err } } @@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return nil } @@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return err } @@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -425,7 +425,6 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential rp := vfs.getResolvingPath(creds, pop) if opts.Flags&linux.O_DIRECTORY != 0 { rp.mustBeDir = true - rp.mustBeDirOrig = true } // Ignore O_PATH for verity, as verity performs extra operations on the fd for verification. // The underlying filesystem that verity wraps opens the fd with O_PATH. @@ -444,7 +443,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) if opts.FileExec { if fd.Mount().Flags.NoExec { @@ -468,7 +467,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential return fd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -480,11 +479,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden for { target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return target, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return "", err } } @@ -533,7 +532,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldParentVD.DecRef(ctx) return nil } @@ -543,7 +542,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldParentVD.DecRef(ctx) return err } @@ -569,7 +568,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.RmdirAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -578,7 +577,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -590,11 +589,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -606,11 +605,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential for { stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return linux.Statx{}, err } } @@ -623,11 +622,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti for { statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return linux.Statfs{}, err } } @@ -652,7 +651,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -661,7 +660,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -686,7 +685,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.UnlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -695,7 +694,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -707,7 +706,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return bep, nil } if checkInvariants { @@ -716,7 +715,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -729,7 +728,7 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede for { names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, nil } if err == syserror.ENOTSUP { @@ -737,11 +736,11 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -754,11 +753,11 @@ func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Creden for { val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return "", err } } @@ -771,11 +770,11 @@ func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Creden for { err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -787,11 +786,11 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre for { err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 8e3146d8d..dfe85f31d 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -243,6 +243,7 @@ func (w *Watchdog) waitForStart() { } stuckStartup.Increment() + metric.WeirdnessMetric.Increment("watchdog_stuck_startup") var buf bytes.Buffer buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout)) @@ -312,10 +313,11 @@ func (w *Watchdog) runTurn() { // New stuck task detected. // // Note that tasks blocked doing IO may be considered stuck in kernel, - // unless they are surrounded b + // unless they are surrounded by // Task.UninterruptibleSleepStart/Finish. tc = &offender{lastUpdateTime: lastUpdateTime} stuckTasks.Increment() + metric.WeirdnessMetric.Increment("watchdog_stuck_tasks") newTaskFound = true } newOffenders[t] = tc diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD index 4f7c02f5d..fd6127b97 100644 --- a/pkg/shim/BUILD +++ b/pkg/shim/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -41,7 +41,19 @@ go_library( "@com_github_containerd_fifo//:go_default_library", "@com_github_containerd_typeurl//:go_default_library", "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "shim_test", + size = "small", + srcs = ["service_test.go"], + library = ":shim", + deps = [ + "//pkg/shim/utils", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/pkg/shim/service.go b/pkg/shim/service.go index 9d9fa8ef6..1f9adcb65 100644 --- a/pkg/shim/service.go +++ b/pkg/shim/service.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "time" @@ -44,6 +45,7 @@ import ( "github.com/containerd/containerd/sys/reaper" "github.com/containerd/typeurl" "github.com/gogo/protobuf/types" + specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cleanup" @@ -944,9 +946,19 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C if err != nil { return nil, fmt.Errorf("read oci spec: %w", err) } - if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + + updated, err := utils.UpdateVolumeAnnotations(spec) + if err != nil { return nil, fmt.Errorf("update volume annotations: %w", err) } + updated = updateCgroup(spec) || updated + + if updated { + if err := utils.WriteSpec(r.Bundle, spec); err != nil { + return nil, err + } + } + runsc.FormatRunscLogPath(r.ID, options.RunscConfig) runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) p := proc.New(r.ID, runtime, stdio.Stdio{ @@ -966,3 +978,39 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C p.Monitor = reaper.Default return p, nil } + +// updateCgroup updates cgroup path for the sandbox to make the sandbox join the +// pod cgroup and not the pause container cgroup. Returns true if the spec was +// modified. Ex.: +// /kubepods/burstable/pod123/abc => kubepods/burstable/pod123 +// +func updateCgroup(spec *specs.Spec) bool { + if !utils.IsSandbox(spec) { + return false + } + if spec.Linux == nil || len(spec.Linux.CgroupsPath) == 0 { + return false + } + + // Search backwards for the pod cgroup path to make the sandbox use it, + // instead of the pause container's cgroup. + parts := strings.Split(spec.Linux.CgroupsPath, string(filepath.Separator)) + for i := len(parts) - 1; i >= 0; i-- { + if strings.HasPrefix(parts[i], "pod") { + var path string + for j := 0; j <= i; j++ { + path = filepath.Join(path, parts[j]) + } + // Add back the initial '/' that may have been lost above. + if filepath.IsAbs(spec.Linux.CgroupsPath) { + path = string(filepath.Separator) + path + } + if spec.Linux.CgroupsPath == path { + return false + } + spec.Linux.CgroupsPath = path + return true + } + } + return false +} diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go new file mode 100644 index 000000000..2d9f07e02 --- /dev/null +++ b/pkg/shim/service_test.go @@ -0,0 +1,121 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/shim/utils" +) + +func TestCgroupPath(t *testing.T) { + for _, tc := range []struct { + name string + path string + want string + }{ + { + name: "simple", + path: "foo/pod123/container", + want: "foo/pod123", + }, + { + name: "absolute", + path: "/foo/pod123/container", + want: "/foo/pod123", + }, + { + name: "no-container", + path: "foo/pod123", + want: "foo/pod123", + }, + { + name: "no-container-absolute", + path: "/foo/pod123", + want: "/foo/pod123", + }, + { + name: "double-pod", + path: "/foo/podium/pod123/container", + want: "/foo/podium/pod123", + }, + { + name: "start-pod", + path: "pod123/container", + want: "pod123", + }, + { + name: "start-pod-absolute", + path: "/pod123/container", + want: "/pod123", + }, + { + name: "slashes", + path: "///foo/////pod123//////container", + want: "/foo/pod123", + }, + { + name: "no-pod", + path: "/foo/nopod123/container", + want: "/foo/nopod123/container", + }, + } { + t.Run(tc.name, func(t *testing.T) { + spec := specs.Spec{ + Linux: &specs.Linux{ + CgroupsPath: tc.path, + }, + } + updated := updateCgroup(&spec) + if spec.Linux.CgroupsPath != tc.want { + t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath) + } + if shouldUpdate := tc.path != tc.want; shouldUpdate != updated { + t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate) + } + }) + } +} + +// Test cases that cgroup path should not be updated. +func TestCgroupNoUpdate(t *testing.T) { + for _, tc := range []struct { + name string + spec *specs.Spec + }{ + { + name: "empty", + spec: &specs.Spec{}, + }, + { + name: "subcontainer", + spec: &specs.Spec{ + Linux: &specs.Linux{ + CgroupsPath: "foo/pod123/container", + }, + Annotations: map[string]string{ + utils.ContainerTypeAnnotation: utils.ContainerTypeContainer, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + if updated := updateCgroup(tc.spec); updated { + t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated) + } + }) + } +} diff --git a/pkg/shim/utils/annotations.go b/pkg/shim/utils/annotations.go index 1e9d3f365..c744800bb 100644 --- a/pkg/shim/utils/annotations.go +++ b/pkg/shim/utils/annotations.go @@ -19,7 +19,9 @@ package utils // These are vendor due to import conflicts. const ( sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory" - containerTypeAnnotation = "io.kubernetes.cri.container-type" + // ContainerTypeAnnotation is they key that defines sandbox or container. + ContainerTypeAnnotation = "io.kubernetes.cri.container-type" containerTypeSandbox = "sandbox" - containerTypeContainer = "container" + // ContainerTypeContainer is the value for container. + ContainerTypeContainer = "container" ) diff --git a/pkg/shim/utils/utils.go b/pkg/shim/utils/utils.go index 7b1cd983e..f183b1bbc 100644 --- a/pkg/shim/utils/utils.go +++ b/pkg/shim/utils/utils.go @@ -18,19 +18,16 @@ package utils import ( "encoding/json" "io/ioutil" - "os" "path/filepath" specs "github.com/opencontainers/runtime-spec/specs-go" ) +const configFilename = "config.json" + // ReadSpec reads OCI spec from the bundle directory. func ReadSpec(bundle string) (*specs.Spec, error) { - f, err := os.Open(filepath.Join(bundle, "config.json")) - if err != nil { - return nil, err - } - b, err := ioutil.ReadAll(f) + b, err := ioutil.ReadFile(filepath.Join(bundle, configFilename)) if err != nil { return nil, err } @@ -41,9 +38,18 @@ func ReadSpec(bundle string) (*specs.Spec, error) { return &spec, nil } +// WriteSpec writes OCI spec to the bundle directory. +func WriteSpec(bundle string, spec *specs.Spec) error { + b, err := json.Marshal(spec) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(bundle, configFilename), b, 0666) +} + // IsSandbox checks whether a container is a sandbox container. func IsSandbox(spec *specs.Spec) bool { - t, ok := spec.Annotations[containerTypeAnnotation] + t, ok := spec.Annotations[ContainerTypeAnnotation] return !ok || t == containerTypeSandbox } diff --git a/pkg/shim/utils/volumes.go b/pkg/shim/utils/volumes.go index 52a428179..6bc75139d 100644 --- a/pkg/shim/utils/volumes.go +++ b/pkg/shim/utils/volumes.go @@ -15,9 +15,7 @@ package utils import ( - "encoding/json" "fmt" - "io/ioutil" "path/filepath" "strings" @@ -89,18 +87,16 @@ func isVolumePath(volume, path string) (bool, error) { } // UpdateVolumeAnnotations add necessary OCI annotations for gvisor -// volume optimization. -func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { - var ( - uid string - err error - ) +// volume optimization. Returns true if the spec was modified. +func UpdateVolumeAnnotations(s *specs.Spec) (bool, error) { + var uid string if IsSandbox(s) { + var err error uid, err = podUID(s) if err != nil { // Skip if we can't get pod UID, because this doesn't work // for containerd 1.1. - return nil + return false, nil } } var updated bool @@ -116,40 +112,48 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { // This is a sandbox. path, err := volumePath(volume, uid) if err != nil { - return fmt.Errorf("get volume path for %q: %w", volume, err) + return false, fmt.Errorf("get volume path for %q: %w", volume, err) } s.Annotations[volumeSourceKey(volume)] = path updated = true } else { // This is a container. for i := range s.Mounts { - // An error is returned for sandbox if source - // annotation is not successfully applied, so - // it is guaranteed that the source annotation - // for sandbox has already been successfully - // applied at this point. + // An error is returned for sandbox if source annotation is not + // successfully applied, so it is guaranteed that the source annotation + // for sandbox has already been successfully applied at this point. // - // The volume name is unique inside a pod, so - // matching without podUID is fine here. + // The volume name is unique inside a pod, so matching without podUID + // is fine here. // - // TODO: Pass podUID down to shim for containers to do - // more accurate matching. + // TODO: Pass podUID down to shim for containers to do more accurate + // matching. if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { - // gVisor requires the container mount type to match - // sandbox mount type. - s.Mounts[i].Type = v + // Container mount type must match the sandbox's mount type. + changeMountType(&s.Mounts[i], v) updated = true } } } } - if !updated { - return nil - } - // Update bundle. - b, err := json.Marshal(s) - if err != nil { - return err + return updated, nil +} + +func changeMountType(m *specs.Mount, newType string) { + m.Type = newType + + // OCI spec allows bind mounts to be specified in options only. So if new type + // is not bind, remove bind/rbind from options. + // + // "For bind mounts (when options include either bind or rbind), the type is + // a dummy, often "none" (not listed in /proc/filesystems)." + if newType != "bind" { + newOpts := make([]string, 0, len(m.Options)) + for _, opt := range m.Options { + if opt != "rbind" && opt != "bind" { + newOpts = append(newOpts, opt) + } + } + m.Options = newOpts } - return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) } diff --git a/pkg/shim/utils/volumes_test.go b/pkg/shim/utils/volumes_test.go index 3e02c6151..5db43cdf1 100644 --- a/pkg/shim/utils/volumes_test.go +++ b/pkg/shim/utils/volumes_test.go @@ -15,11 +15,9 @@ package utils import ( - "encoding/json" "fmt" "io/ioutil" "os" - "path/filepath" "reflect" "testing" @@ -47,60 +45,60 @@ func TestUpdateVolumeAnnotations(t *testing.T) { } for _, test := range []struct { - desc string + name string spec *specs.Spec expected *specs.Spec expectErr bool expectUpdate bool }{ { - desc: "volume annotations for sandbox", + name: "volume annotations for sandbox", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + sandboxLogDirAnnotation: testLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", - "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + sandboxLogDirAnnotation: testLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, }, }, expectUpdate: true, }, { - desc: "volume annotations for sandbox with legacy log path", + name: "volume annotations for sandbox with legacy log path", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + sandboxLogDirAnnotation: testLegacyLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", - "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + sandboxLogDirAnnotation: testLegacyLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, }, }, expectUpdate: true, }, { - desc: "tmpfs: volume annotations for container", + name: "tmpfs: volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -117,10 +115,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ @@ -139,16 +137,16 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expectUpdate: true, }, { - desc: "bind: volume annotations for container", + name: "bind: volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -159,10 +157,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "container", + volumeKeyPrefix + testVolumeName + ".type": "bind", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ @@ -175,63 +173,63 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "container", + volumeKeyPrefix + testVolumeName + ".type": "bind", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expectUpdate: true, }, { - desc: "should not return error without pod log directory", + name: "should not return error without pod log directory", spec: &specs.Spec{ Annotations: map[string]string{ - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, }, { - desc: "should return error if volume path does not exist", + name: "should return error if volume path does not exist", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount.notexist.share": "pod", - "dev.gvisor.spec.mount.notexist.type": "tmpfs", - "dev.gvisor.spec.mount.notexist.options": "ro", + sandboxLogDirAnnotation: testLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + "notexist.share": "pod", + volumeKeyPrefix + "notexist.type": "tmpfs", + volumeKeyPrefix + "notexist.options": "ro", }, }, expectErr: true, }, { - desc: "no volume annotations for sandbox", + name: "no volume annotations for sandbox", spec: &specs.Spec{ Annotations: map[string]string{ sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, + ContainerTypeAnnotation: containerTypeSandbox, }, }, expected: &specs.Spec{ Annotations: map[string]string{ sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, + ContainerTypeAnnotation: containerTypeSandbox, }, }, }, { - desc: "no volume annotations for container", + name: "no volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -248,7 +246,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, + ContainerTypeAnnotation: ContainerTypeContainer, }, }, expected: &specs.Spec{ @@ -267,17 +265,51 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, + ContainerTypeAnnotation: ContainerTypeContainer, }, }, }, + { + name: "bind options removed", + spec: &specs.Spec{ + Annotations: map[string]string{ + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, + }, + Mounts: []specs.Mount{ + { + Destination: "/dst", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro", "bind", "rbind"}, + }, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, + }, + Mounts: []specs.Mount{ + { + Destination: "/dst", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + }, + expectUpdate: true, + }, } { - t.Run(test.desc, func(t *testing.T) { - bundle, err := ioutil.TempDir(dir, "test-bundle") - if err != nil { - t.Fatalf("Create test bundle: %v", err) - } - err = UpdateVolumeAnnotations(bundle, test.spec) + t.Run(test.name, func(t *testing.T) { + updated, err := UpdateVolumeAnnotations(test.spec) if test.expectErr { if err == nil { t.Fatal("Expected error, but got nil") @@ -290,18 +322,8 @@ func TestUpdateVolumeAnnotations(t *testing.T) { if !reflect.DeepEqual(test.expected, test.spec) { t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) } - if test.expectUpdate { - b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json")) - if err != nil { - t.Fatalf("Read spec from bundle: %v", err) - } - var spec specs.Spec - if err := json.Unmarshal(b, &spec); err != nil { - t.Fatalf("Unmarshal spec: %v", err) - } - if !reflect.DeepEqual(test.expected, &spec) { - t.Fatalf("Expected %+v, got %+v", test.expected, &spec) - } + if test.expectUpdate != updated { + t.Errorf("Expected %v, got %v", test.expected, updated) } }) } diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index d6c89c7e9..08d06e37b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["statefile.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/binary", "//pkg/compressio", "//pkg/state/wire", ], diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index bdfb800fb..d27c8c8a8 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -48,6 +48,7 @@ import ( "compress/flate" "crypto/hmac" "crypto/sha256" + "encoding/binary" "encoding/json" "fmt" "hash" @@ -55,7 +56,6 @@ import ( "strings" "time" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/compressio" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -90,6 +90,13 @@ type WriteCloser interface { io.Closer } +func writeMetadataLen(w io.Writer, val uint64) error { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], val) + _, err := w.Write(buf[:]) + return err +} + // NewWriter returns a state data writer for a statefile. // // Note that the returned WriteCloser must be closed. @@ -127,7 +134,7 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser } // Metadata length. - if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil { + if err := writeMetadataLen(mw, uint64(len(b))); err != nil { return nil, err } // Metadata bytes; io.MultiWriter will return a short write error if @@ -158,6 +165,14 @@ func MetadataUnsafe(r io.Reader) (map[string]string, error) { return metadata(r, nil) } +func readMetadataLen(r io.Reader) (uint64, error) { + var buf [8]byte + if _, err := io.ReadFull(r, buf[:]); err != nil { + return 0, err + } + return binary.BigEndian.Uint64(buf[:]), nil +} + // metadata validates the magic header and reads out the metadata from a state // data stream. func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { @@ -183,7 +198,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { } }() - metadataLen, err := binary.ReadUint64(r, binary.BigEndian) + metadataLen, err := readMetadataLen(r) if err != nil { return nil, err } diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 79e564de6..90be24e15 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -38,7 +38,7 @@ var ( ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) - ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), nil) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), linux.NOERRNO) ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index b5881ea3c..d70521f32 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -34,24 +34,19 @@ type Error struct { // linux.Errno. noTranslation bool - // errno is the linux.Errno this Error should be translated to. nil means - // that this Error should be translated to a nil linux.Errno. - errno *linux.Errno + // errno is the linux.Errno this Error should be translated to. + errno linux.Errno } // New creates a new Error and adds a translation for it. // // New must only be called at init. -func New(message string, linuxTranslation *linux.Errno) *Error { +func New(message string, linuxTranslation linux.Errno) *Error { err := &Error{message: message, errno: linuxTranslation} - if linuxTranslation == nil { - return err - } - // TODO(b/34162363): Remove this. - errno := linuxTranslation.Number() - if errno <= 0 || errno >= len(linuxBackwardsTranslations) { + errno := linuxTranslation + if errno < 0 || int(errno) >= len(linuxBackwardsTranslations) { panic(fmt.Sprint("invalid errno: ", errno)) } @@ -74,7 +69,7 @@ func New(message string, linuxTranslation *linux.Errno) *Error { // NewDynamic should only be used sparingly and not be used for static error // messages. Errors with static error messages should be declared with New as // global variables. -func NewDynamic(message string, linuxTranslation *linux.Errno) *Error { +func NewDynamic(message string, linuxTranslation linux.Errno) *Error { return &Error{message: message, errno: linuxTranslation} } @@ -87,7 +82,7 @@ func NewWithoutTranslation(message string) *Error { return &Error{message: message, noTranslation: true} } -func newWithHost(message string, linuxTranslation *linux.Errno, hostErrno unix.Errno) *Error { +func newWithHost(message string, linuxTranslation linux.Errno, hostErrno unix.Errno) *Error { e := New(message, linuxTranslation) addLinuxHostTranslation(hostErrno, e) return e @@ -119,10 +114,10 @@ func (e *Error) ToError() error { if e.noTranslation { panic(fmt.Sprintf("error %q does not support translation", e.message)) } - if e.errno == nil { + errno := int(e.errno) + if errno == linux.NOERRNO { return nil } - errno := e.errno.Number() if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok { panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno)) } @@ -131,7 +126,7 @@ func (e *Error) ToError() error { // ToLinux converts the Error to a Linux ABI error that can be returned to the // application. -func (e *Error) ToLinux() *linux.Errno { +func (e *Error) ToLinux() linux.Errno { if e.noTranslation { panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message)) } diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index 7d2f5adf6..76bee5a64 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,12 +8,3 @@ go_library( visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) - -go_test( - name = "syserror_test", - srcs = ["syserror_test.go"], - deps = [ - ":syserror", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index aa30cfc85..ea46c30da 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -22,12 +22,14 @@ go_library( "errors.go", "sock_err_list.go", "socketops.go", + "stdclock.go", + "stdclock_state.go", "tcpip.go", - "time_unsafe.go", "timer.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/atomicbitops", "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/waiter", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 12c39dfa3..18e6cc3cd 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1607,6 +1607,17 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { } } +// IPv6UnknownOption validates that an extension header option is the +// unknown header option. +func IPv6UnknownOption() IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + _, ok := opt.(*header.IPv6UnknownExtHdrOption) + if !ok { + t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt) + } + } +} + // IgnoreCmpPath returns a cmp.Option that ignores listed field paths. func IgnoreCmpPath(paths ...string) cmp.Option { ignores := map[string]struct{}{} diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index ebb4b2c1d..1c913b5e1 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -60,9 +60,13 @@ func IPv4(pkt *stack.PacketBuffer) bool { return false } ipHdr = header.IPv4(hdr) + length := int(ipHdr.TotalLength()) - len(hdr) + if length < 0 { + return false + } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(length) return true } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index f75ee34ab..ef9126deb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -123,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) { q.notify = notify } +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.GSOEndpoint = (*Endpoint)(nil) + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -130,6 +133,7 @@ type Endpoint struct { mtu uint32 linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities + SupportedGSOKind stack.SupportedGSO // Outbound packet queue. q *queue @@ -211,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.LinkEPCapabilities } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (*Endpoint) GSOMaxSize() uint32 { return 1 << 15 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + return e.SupportedGSOKind +} + // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*Endpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index f042df82e..d971194e6 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/binary", "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index feb79fe0e..bddb1d0a2 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,7 +45,6 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string { } } +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -134,6 +136,9 @@ type endpoint struct { // wg keeps track of running goroutines. wg sync.WaitGroup + + // gsoKind is the supported kind of GSO. + gsoKind stack.SupportedGSO } // Options specify the details about the fd-based endpoint to be created. @@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) { if isSocket { if opts.GSOMaxSize != 0 { if opts.SoftwareGSOEnabled { - e.caps |= stack.CapabilitySoftwareGSO + e.gsoKind = stack.SWGSOSupported } else { - e.caps |= stack.CapabilityHardwareGSO + e.gsoKind = stack.HWGSOSupported } e.gsoMaxSize = opts.GSOMaxSize } @@ -403,6 +408,35 @@ type virtioNetHdr struct { csumOffset uint16 } +// marshal serializes h to a newly-allocated byte slice, in little-endian byte +// order. +// +// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used +// for general serialization. This makes it difficult to use go-marshal for +// virtio types, as go-marshal implicitly uses the native byte ordering. +func (h *virtioNetHdr) marshal() []byte { + buf := [virtioNetHdrSize]byte{ + 0: byte(h.flags), + 1: byte(h.gsoType), + + // Manually lay out the fields in little-endian byte order. Little endian => + // least significant bit goes to the lower address. + + 2: byte(h.hdrLen), + 3: byte(h.hdrLen >> 8), + + 4: byte(h.gsoSize), + 5: byte(h.gsoSize >> 8), + + 6: byte(h.csumStart), + 7: byte(h.csumStart >> 8), + + 8: byte(h.csumOffset), + 9: byte(h.csumOffset >> 8), + } + return buf[:] +} + // These constants are declared in linux/virtio_net.h. const ( _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1 @@ -441,7 +475,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) @@ -463,7 +497,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol } } - vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf := vnetHdr.marshal() builder.Add(vnetHdrBuf) } @@ -482,7 +516,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp } var vnetHdrBuf []byte - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) @@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf = vnetHdr.marshal() } var builder iovec.Builder @@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { } } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// SupportsHWGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + return e.gsoKind +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (e *endpoint) ARPHardwareType() header.ARPHardwareType { if e.hdrSize > 0 { diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index a7adf822b..4b7ef3aac 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -128,7 +128,7 @@ type readVDispatcher struct { func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil } @@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { bufs: make([]*iovecBuffer, MaxMsgsPerRecv), msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported for i := range d.bufs { d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 89df35822..3e816b0c7 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.SupportedGSO() + } + return stack.GSONotSupported +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.child.ARPHardwareType() diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index bba6a6973..b1a28491d 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -25,6 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + // endpoint represents a LinkEndpoint which implements a FIFO queue for all // outgoing packets. endpoint can have 1 or more underlying queueDispatchers. // All outgoing packets are consistenly hashed to a single underlying queue @@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.lower.LinkAddress() } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { if gso, ok := e.lower.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() @@ -149,6 +152,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.SupportedGSO() + } + return stack.GSONotSupported +} + // WritePacket implements stack.LinkEndpoint.WritePacket. // // The packet must have the following fields populated: diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 6905b9ccb..a72eb1aad 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -47,7 +47,7 @@ go_test( library = ":arp", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index e867b3c3f..0df39ae81 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 90075a70c..56b76a284 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragData := r.holes[i].pkt.Data() - resPkt.Data().ReadFromData(fragData, fragData.Size()) + stack.MergeFragment(resPkt, r.holes[i].pkt) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index d21b4c7ef..fd944ce99 100644 --- a/pkg/tcpip/network/internal/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -6,6 +6,7 @@ go_library( name = "ip", srcs = [ "duplicate_address_detection.go", + "errors.go", "generic_multicast_protocol.go", "stats.go", ], diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index eed49f5d2..5123b7d6a 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) } + configs.Validate() + *d = DAD{ opts: opts, configs: configs, diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go new file mode 100644 index 000000000..94f1cd1cb --- /dev/null +++ b/pkg/tcpip/network/internal/ip/errors.go @@ -0,0 +1,85 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// ForwardingError represents an error that occured while trying to forward +// a packet. +type ForwardingError interface { + isForwardingError() + fmt.Stringer +} + +// ErrTTLExceeded indicates that the received packet's TTL has been exceeded. +type ErrTTLExceeded struct{} + +func (*ErrTTLExceeded) isForwardingError() {} + +func (*ErrTTLExceeded) String() string { return "ttl exceeded" } + +// ErrParameterProblem indicates the received packet had a problem with an IP +// parameter. +type ErrParameterProblem struct{} + +func (*ErrParameterProblem) isForwardingError() {} + +func (*ErrParameterProblem) String() string { return "parameter problem" } + +// ErrLinkLocalSourceAddress indicates the received packet had a link-local +// source address. +type ErrLinkLocalSourceAddress struct{} + +func (*ErrLinkLocalSourceAddress) isForwardingError() {} + +func (*ErrLinkLocalSourceAddress) String() string { return "link local destination address" } + +// ErrLinkLocalDestinationAddress indicates the received packet had a link-local +// destination address. +type ErrLinkLocalDestinationAddress struct{} + +func (*ErrLinkLocalDestinationAddress) isForwardingError() {} + +func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" } + +// ErrNoRoute indicates that a route for the received packet couldn't be found. +type ErrNoRoute struct{} + +func (*ErrNoRoute) isForwardingError() {} + +func (*ErrNoRoute) String() string { return "no route" } + +// ErrMessageTooLong indicates the packet was too big for the outgoing MTU. +// +// +stateify savable +type ErrMessageTooLong struct{} + +func (*ErrMessageTooLong) isForwardingError() {} + +func (*ErrMessageTooLong) String() string { return "message too long" } + +// ErrOther indicates the packet coould not be forwarded for a reason +// captured by the contained error. +type ErrOther struct { + Err tcpip.Error +} + +func (*ErrOther) isForwardingError() {} + +func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index ac35d81e7..d22974b12 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ip holds IPv4/IPv6 common utilities. package ip import ( diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index d06b26309..0c2b62127 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -16,80 +16,145 @@ package ip import "gvisor.dev/gvisor/pkg/tcpip" +// LINT.IfChange(MultiCounterIPForwardingStats) + +// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter +// may have several versions. +type MultiCounterIPForwardingStats struct { + // Unrouteable is the number of IP packets received which were dropped + // because the netstack could not construct a route to their + // destination. + Unrouteable tcpip.MultiCounterStat + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL tcpip.MultiCounterStat + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource tcpip.MultiCounterStat + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination tcpip.MultiCounterStat + + // PacketTooBig is the number of IP packets which were dropped because they + // were too big for the outgoing MTU. + PacketTooBig tcpip.MultiCounterStat + + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem tcpip.MultiCounterStat + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { + m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) + m.Errors.Init(a.Errors, b.Errors) + m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) + m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) + m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) + m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) + m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) +} + +// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) + // LINT.IfChange(MultiCounterIPStats) // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { - // PacketsReceived is the number of IP packets received from the link layer. + // PacketsReceived is the number of IP packets received from the link + // layer. PacketsReceived tcpip.MultiCounterStat - // DisabledPacketsReceived is the number of IP packets received from the link - // layer when the IP layer is disabled. + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the number of IP packets received from + // the link layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat - // InvalidDestinationAddressesReceived is the number of IP packets received - // with an unknown or invalid destination address. + // InvalidDestinationAddressesReceived is the number of IP packets + // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived tcpip.MultiCounterStat - // InvalidSourceAddressesReceived is the number of IP packets received with a - // source address that should never have been received on the wire. + // InvalidSourceAddressesReceived is the number of IP packets received + // with a source address that should never have been received on the + // wire. InvalidSourceAddressesReceived tcpip.MultiCounterStat - // PacketsDelivered is the number of incoming IP packets that are successfully + // PacketsDelivered is the number of incoming IP packets successfully // delivered to the transport layer. PacketsDelivered tcpip.MultiCounterStat // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent tcpip.MultiCounterStat - // OutgoingPacketErrors is the number of IP packets which failed to write to a - // link-layer endpoint. + // OutgoingPacketErrors is the number of IP packets which failed to + // write to a link-layer endpoint. OutgoingPacketErrors tcpip.MultiCounterStat - // MalformedPacketsReceived is the number of IP Packets that were dropped due - // to the IP packet header failing validation checks. + // MalformedPacketsReceived is the number of IP Packets that were + // dropped due to the IP packet header failing validation checks. MalformedPacketsReceived tcpip.MultiCounterStat - // MalformedFragmentsReceived is the number of IP Fragments that were dropped - // due to the fragment failing validation checks. + // MalformedFragmentsReceived is the number of IP Fragments that were + // dropped due to the fragment failing validation checks. MalformedFragmentsReceived tcpip.MultiCounterStat // IPTablesPreroutingDropped is the number of IP packets dropped in the // Prerouting chain. IPTablesPreroutingDropped tcpip.MultiCounterStat - // IPTablesInputDropped is the number of IP packets dropped in the Input - // chain. + // IPTablesInputDropped is the number of IP packets dropped in the + // Input chain. IPTablesInputDropped tcpip.MultiCounterStat - // IPTablesOutputDropped is the number of IP packets dropped in the Output - // chain. + // IPTablesForwardDropped is the number of IP packets dropped in the + // Forward chain. + IPTablesForwardDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the number of IP packets dropped in the + // Output chain. IPTablesOutputDropped tcpip.MultiCounterStat - // IPTablesPostroutingDropped is the number of IP packets dropped in the - // Postrouting chain. + // IPTablesPostroutingDropped is the number of IP packets dropped in + // the Postrouting chain. IPTablesPostroutingDropped tcpip.MultiCounterStat - // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out - // of IPStats. + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option + // stats out of IPStats. // OptionTimestampReceived is the number of Timestamp options seen. OptionTimestampReceived tcpip.MultiCounterStat - // OptionRecordRouteReceived is the number of Record Route options seen. + // OptionRecordRouteReceived is the number of Record Route options + // seen. OptionRecordRouteReceived tcpip.MultiCounterStat - // OptionRouterAlertReceived is the number of Router Alert options seen. + // OptionRouterAlertReceived is the number of Router Alert options + // seen. OptionRouterAlertReceived tcpip.MultiCounterStat // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived tcpip.MultiCounterStat + + // Forwarding collects stats related to IP forwarding. + Forwarding MultiCounterIPForwardingStats } // Init sets internal counters to track a and b counters. func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived) m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) @@ -100,12 +165,14 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived) m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived) m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) + m.Forwarding.Init(&a.Forwarding, &b.Forwarding) } // LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index 1c4f583c7..cec3e62c4 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -4,10 +4,7 @@ package(licenses = ["notice"]) go_library( name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], + srcs = ["testutil.go"], visibility = [ "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/internal/fragmentation:__pkg__", diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index e2cf24b67..605e9ef8d 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -19,8 +19,6 @@ package testutil import ( "fmt" "math/rand" - "reflect" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). - if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7ee0495d9..c90974693 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -62,7 +62,7 @@ go_test( library = ":ipv4", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index f663fdc0b..d1a82b584 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet return } - // Skip the ip header, then deliver the error. - pkt.Data().TrimFront(hlen) + // Keep needed information before trimming header. p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) + dstAddr := hdr.DestinationAddress() + // Skip the ip header, then deliver the error. + pkt.Data().DeleteFront(hlen) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data().TrimFront(header.ICMPv4MinimumSize) - switch h.Code() { + mtu := h.MTU() + code := h.Code() + pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: - networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } @@ -383,6 +387,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // icmpReason is a marker interface for IPv4 specific ICMP errors. type icmpReason interface { isICMPReason() + // isForwarding indicates whether or not the error arose while attempting to + // forward a packet. isForwarding() bool } @@ -442,6 +448,39 @@ func (r *icmpReasonParamProblem) isForwarding() bool { return r.forwarding } +// icmpReasonNetworkUnreachable is an error in which the network specified in +// the internet destination field of the datagram is unreachable. +type icmpReasonNetworkUnreachable struct{} + +func (*icmpReasonNetworkUnreachable) isICMPReason() {} +func (*icmpReasonNetworkUnreachable) isForwarding() bool { + // If we hit a Net Unreachable error, then we know we are operating as + // a router. As per RFC 792 page 5, Destination Unreachable Message, + // + // If, according to the information in the gateway's routing tables, + // the network specified in the internet destination field of a + // datagram is unreachable, e.g., the distance to the network is + // infinity, the gateway may send a destination unreachable message to + // the internet source host of the datagram. + return true +} + +// icmpReasonFragmentationNeeded is an error where a packet requires +// fragmentation while also having the Don't Fragment flag set, as per RFC 792 +// page 3, Destination Unreachable Message. +type icmpReasonFragmentationNeeded struct{} + +func (*icmpReasonFragmentationNeeded) isICMPReason() {} +func (*icmpReasonFragmentationNeeded) isForwarding() bool { + // If we hit a Don't Fragment error, then we know we are operating as a router. + // As per RFC 792 page 4, Destination Unreachable Message, + // + // Another case is when a datagram must be fragmented to be forwarded by a + // gateway yet the Don't Fragment flag is on. In this case the gateway must + // discard the datagram and may return a destination unreachable message. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -610,6 +649,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetworkUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4NetUnreachable) + counter = sent.dstUnreachable + case *icmpReasonFragmentationNeeded: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a0bc06465..23178277a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -62,9 +63,15 @@ const ( fragmentblockSize = 8 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -81,6 +88,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -150,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { e.mu.Lock() defer e.mu.Unlock() + if !e.setForwarding(forwarding) { + return + } + if forwarding { // There does not seem to be an RFC requirement for a node to join the all // routers multicast address but @@ -433,6 +464,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn } if packetMustBeFragmented(pkt, networkMTU) { + h := header.IPv4(pkt.NetworkHeader().View()) + if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment + // is set but the packet must be fragmented for the non-forwarding case. + return &tcpip.ErrMessageTooLong{} + } sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we @@ -599,22 +636,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() - if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { - // As per RFC 3927 section 7, - // - // A router MUST NOT forward a packet with an IPv4 Link-Local source or - // destination address, irrespective of the router's default route - // configuration or routes obtained from dynamic routing protocols. - // - // A router which receives a packet with an IPv4 Link-Local source or - // destination address MUST NOT forward the packet. This prevents - // forwarding of packets back onto the network segment from which they - // originated, or to any other segment. - return nil + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. + if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} } ttl := h.TTL() @@ -624,7 +664,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // If the gateway processing a datagram finds the time to live field // is zero it must discard the datagram. The gateway may also notify // the source host via the time exceeded message. - return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } if opts := h.Options(); len(opts) != 0 { @@ -635,10 +680,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { pointer: optProblem.Pointer, forwarding: true, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - e.stats.ip.MalformedPacketsReceived.Increment() } - return nil // option problems are not reported locally. + return &ip.ErrParameterProblem{} } copied := copy(opts, newOpts) if copied != len(newOpts) { @@ -655,18 +698,44 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -680,10 +749,28 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + IsForwardedPacket: true, + })); err.(type) { + case nil: + return nil + case *tcpip.ErrMessageTooLong: + // As per RFC 792, page 4, Destination Unreachable: + // + // Another case is when a datagram must be fragmented to be forwarded by a + // gateway yet the Don't Fragment flag is on. In this case the gateway must + // discard the datagram and may return a destination unreachable message. + // + // WriteHeaderIncludedPacket checks for the presence of the Don't Fragment bit + // while sending the packet and returns this error iff fragmentation is + // necessary and the bit is also set. + _ = e.protocol.returnError(&icmpReasonFragmentationNeeded{}, pkt) + return &ip.ErrMessageTooLong{} + default: + return &ip.ErrOther{Err: err} + } } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -764,6 +851,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats + stats.ip.ValidPacketsReceived.Increment() srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -794,11 +882,30 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) addressEndpoint.DecRef() pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.ip.InvalidDestinationAddressesReceived.Increment() return } - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() + case *ip.ErrMessageTooLong: + stats.ip.Forwarding.PacketTooBig.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + stats.ip.Forwarding.Errors.Increment() return } @@ -955,8 +1062,8 @@ func (e *endpoint) Close() { // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err == nil { @@ -967,8 +1074,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) } @@ -981,8 +1088,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() loopback := e.nic.IsLoopback() return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { @@ -1067,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1088,12 +1194,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - ids []uint32 hashIV uint32 @@ -1206,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 7d413c455..da9cc0ae8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -112,67 +112,103 @@ func TestExcludeBroadcast(t *testing.T) { }) } +type forwardedPacket struct { + fragments []fragmentInfo +} + func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 randomTimeOffset = 0x10203040 ) - ipv4Addr1 := tcpip.AddressWithPrefix{ + incomingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, } - ipv4Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), PrefixLen: 8, } - remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) - remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2") + multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0") tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code + name string + TTL uint8 + sourceAddr tcpip.Address + destAddr tcpip.Address + expectErrorICMP bool + ipFlags uint8 + mtu uint32 + payloadLength int + options header.IPv4Options + forwardedOptions header.IPv4Options + icmpType header.ICMPv4Type + icmpCode header.ICMPv4Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool + expectPacketForwarded bool + expectedFragmentsForwarded []fragmentInfo }{ { name: "TTL of zero", TTL: 0, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, expectErrorICMP: true, icmpType: header.ICMPv4TimeExceeded, icmpCode: header.ICMPv4TTLExceeded, + mtu: ipv4.MaxTotalSize, }, { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, + name: "TTL of one", + TTL: 1, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, + name: "four EOL options", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, + options: header.IPv4Options{0, 0, 0, 0}, + forwardedOptions: header.IPv4Options{0, 0, 0, 0}, }, { - name: "TS type 1 full", - TTL: 2, + name: "TS type 1 full", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 12, 13, 0xF1, 192, 168, 1, 12, @@ -183,8 +219,11 @@ func TestForwarding(t *testing.T) { icmpCode: header.ICMPv4UnusedCode, }, { - name: "TS type 0", - TTL: 2, + name: "TS type 0", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, @@ -201,10 +240,14 @@ func TestForwarding(t *testing.T) { 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, + expectPacketForwarded: true, }, { - name: "end of options list", - TTL: 2, + name: "end of options list", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, @@ -220,11 +263,89 @@ func TestForwarding(t *testing.T) { 0, 0, 0, // 7 bytes unknown option removed. 0, 0, 0, 0, }, + expectPacketForwarded: true, + }, + { + name: "Network unreachable", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: unreachableIPv4Addr, + expectErrorICMP: true, + mtu: ipv4.MaxTotalSize, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4NetUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + destAddr: multicastIPv4Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: linkLocalIPv4Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv4Addr, + destAddr: remoteIPv4Addr2, + expectLinkLocalSourceError: true, + }, + { + name: "Fragmentation needed and DF set", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + ipFlags: header.IPv4FlagDontFragment, + // We've picked this MTU because it is: + // + // 1) Greater than the minimum MTU that IPv4 hosts are required to process + // (576 bytes). As per RFC 1812, Section 4.3.2.3: + // + // The ICMP datagram SHOULD contain as much of the original datagram as + // possible without the length of the ICMP datagram exceeding 576 bytes. + // + // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a + // complete ICMP packet on the incoming endpoint (and make assertions about + // it). + // + // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose + // size exceeds the MTU. + mtu: 1000, + payloadLength: 1004, + expectErrorICMP: true, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4FragmentationNeeded, + }, + { + name: "Fragmentation needed and DF not set", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: 1000, + payloadLength: 1004, + expectPacketForwarded: true, + // Combined, these fragments have length of 1012 octets, which is equal to + // the length of the payload (1004 octets), plus the length of the ICMP + // header (8 octets). + expectedFragmentsForwarded: []fragmentInfo{ + // The first fragment has a length of the greatest multiple of 8 which is + // less than or equal to to `mtu - header.IPv4MinimumSize`. + {offset: 0, payloadSize: uint16(976), more: true}, + // The next fragment holds the rest of the packet. + {offset: uint16(976), payloadSize: 36, more: false}, + }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, @@ -236,46 +357,52 @@ func TestForwarding(t *testing.T) { clock.Advance(time.Millisecond * randomTimeOffset) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} - if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) } - e2 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + expectedEmittedPacketCount := 1 + if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount { + expectedEmittedPacketCount = len(test.expectedFragmentsForwarded) + } + outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr) + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} - if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv4Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv4Addr.Subnet(), + NIC: incomingNICID, }, { - Destination: ipv4Addr2.Subnet(), - NIC: nicID2, + Destination: outgoingIPv4Addr.Subnet(), + NIC: outgoingNICID, }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpHeaderLength := header.ICMPv4MinimumSize + totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + hdr := buffer.NewPrependable(totalLength) + hdr.Prepend(test.payloadLength) + icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv4Echo) @@ -284,11 +411,12 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(^header.Checksum(icmp, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, + TotalLength: uint16(totalLength), Protocol: uint8(header.ICMPv4ProtocolNumber), TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, + Flags: test.ipFlags, }) if len(test.options) != 0 { ip.SetHeaderLength(uint8(ipHeaderLength)) @@ -303,51 +431,122 @@ func TestForwarding(t *testing.T) { requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } + // We expect the ICMP packet to contain as much of the original packet as + // possible up to a limit of 576 bytes, split between payload, IP header, + // and ICMP header. + expectedICMPPayloadLength := func() int { + maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize + maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength + if len(hdr.View()) > maxICMPPayloadLength { + return maxICMPPayloadLength + } + return len(hdr.View()) + } + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), + checker.SrcAddr(incomingIPv4Addr.Address), + checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), ), ) + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) + if test.expectPacketForwarded { + if len(test.expectedFragmentsForwarded) != 0 { + fragmentedPackets := []*stack.PacketBuffer{} + for i := 0; i < len(test.expectedFragmentsForwarded); i++ { + reply, ok = outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP Echo fragment through outgoing NIC") + } + fragmentedPackets = append(fragmentedPackets, reply.Pkt) + } + + // The forwarded packet's TTL will have been decremented. + ipHeader := header.IPv4(requestPkt.NetworkHeader().View()) + ipHeader.SetTTL(ipHeader.TTL() - 1) + + // Forwarded packets have available header bytes equalling the sum of the + // maximum IP header size and the maximum size allocated for link layer + // headers. In this case, no size is allocated for link layer headers. + expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize + if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + t.Error(err) + } + } else { + reply, ok = outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), + checker.TTL(test.TTL-1), + checker.IPv4Options(test.forwardedOptions), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) } } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") + if reply, ok = outgoingEndpoint.Read(); ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + } + boolToInt := func(val bool) uint64 { + if val { + return 1 } + return 0 + } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { + t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) } }) } @@ -1170,13 +1369,25 @@ func TestIPv4Sanity(t *testing.T) { } } -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { +// compareFragments compares the contents of a set of fragmented packets against +// the contents of a source packet. +// +// If withIPHeader is set to true, we will validate the fragmented packets' IP +// headers against the source packet's IP header. If set to false, we validate +// the fragmented packets' IP headers against each other. +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error { // Make a complete array of the sourcePacket packet. - source := header.IPv4(packets[0].NetworkHeader().View()) + var source header.IPv4 vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) + + // If the packet to be fragmented contains an IPv4 header, use that header for + // validating fragment headers. Else, use the header of the first fragment. + if withIPHeader { + source = header.IPv4(vv.ToView()) + } else { + source = header.IPv4(packets[0].NetworkHeader().View()) + source = append(source, vv.ToView()...) + } // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -1199,12 +1410,12 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB if got := fragmentIPHeader.TransportProtocol(); got != proto { return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } - if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } + if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes) + } if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) } @@ -1220,6 +1431,14 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) + + // If we are validating against the original IP header, we should exclude the + // ID field, which will only be set fo fragmented packets. + if withIPHeader { + fragmentIPHeader.SetID(0) + fragmentIPHeader.SetChecksum(0) + fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum()) + } if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } @@ -1348,7 +1567,7 @@ func TestFragmentationWritePacket(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { t.Error(err) } }) @@ -1429,7 +1648,7 @@ func TestFragmentationWritePackets(t *testing.T) { } fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { t.Error(err) } }) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index a637f9d50..d1f9e3cf5 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index db998e83e..f99cbf8f3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1319db32b..307e1972d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe return } + // Keep needed information before trimming header. + p := hdr.TransportProtocol() + dstAddr := hdr.DestinationAddress() + // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6MinimumSize) if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // because they don't have the transport headers. return } + p = fragHdr.TransportProtocol() // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) } - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } + pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: @@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + code := header.ICMPv6(hdr).Code() + pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) + switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -741,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - stack := e.protocol.stack - - // Is the networking stack operating as a router? - if !stack.Forwarding(ProtocolNumber) { - // ... No, silently drop the packet. + if !e.Forwarding() { received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -951,6 +951,19 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo // icmpReason is a marker interface for IPv6 specific ICMP errors. type icmpReason interface { isICMPReason() + // isForwarding indicates whether or not the error arose while attempting to + // forward a packet. + isForwarding() bool + // respondToMulticast indicates whether this error falls under the exception + // outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are two + // exceptions to this rule: (1) the Packet Too Big Message (Section 3.2) to + // allow Path MTU discovery to work for IPv6 multicast, and (2) the Parameter + // Problem Message, Code 2 (Section 3.4) reporting an unrecognized IPv6 + // option (see Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondsToMulticast() bool } // icmpReasonParameterProblem is an error during processing of extension headers @@ -958,18 +971,6 @@ type icmpReason interface { type icmpReasonParameterProblem struct { code header.ICMPv6Code - // respondToMulticast indicates that we are sending a packet that falls under - // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: - // - // (e.3) A packet destined to an IPv6 multicast address. (There are - // two exceptions to this rule: (1) the Packet Too Big Message - // (Section 3.2) to allow Path MTU discovery to work for IPv6 - // multicast, and (2) the Parameter Problem Message, Code 2 - // (Section 3.4) reporting an unrecognized IPv6 option (see - // Section 4.2 of [IPv6]) that has the Option Type highest- - // order two bits set to 10). - respondToMulticast bool - // pointer is defined in the RFC 4443 setion 3.4 which reads: // // Pointer Identifies the octet offset within the invoking packet @@ -979,9 +980,20 @@ type icmpReasonParameterProblem struct { // packet if the field in error is beyond what can fit // in the maximum size of an ICMPv6 error message. pointer uint32 + + forwarding bool + + respondToMulticast bool } func (*icmpReasonParameterProblem) isICMPReason() {} +func (p *icmpReasonParameterProblem) isForwarding() bool { + return p.forwarding +} + +func (p *icmpReasonParameterProblem) respondsToMulticast() bool { + return p.respondToMulticast +} // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. @@ -989,12 +1001,76 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +func (*icmpReasonPortUnreachable) isForwarding() bool { + return false +} + +func (*icmpReasonPortUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonNetUnreachable is an error where no route can be found to the +// network of the final destination. +type icmpReasonNetUnreachable struct{} + +func (*icmpReasonNetUnreachable) isICMPReason() {} + +func (*icmpReasonNetUnreachable) isForwarding() bool { + // If we hit a Network Unreachable error, then we also know we are + // operating as a router. As per RFC 4443 section 3.1: + // + // If the reason for the failure to deliver is lack of a matching + // entry in the forwarding node's routing table, the Code field is + // set to 0 (Network Unreachable). + return true +} + +func (*icmpReasonNetUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonFragmentationNeeded is an error where a packet is to big to be sent +// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. +type icmpReasonPacketTooBig struct{} + +func (*icmpReasonPacketTooBig) isICMPReason() {} + +func (*icmpReasonPacketTooBig) isForwarding() bool { + // If we hit a Packet Too Big error, then we know we are operating as a router. + // As per RFC 4443 section 3.2: + // + // A Packet Too Big MUST be sent by a router in response to a packet that it + // cannot forward because the packet is larger than the MTU of the outgoing + // link. + return true +} + +func (*icmpReasonPacketTooBig) respondsToMulticast() bool { + return true +} + // icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in // transit to its final destination, as per RFC 4443 section 3.3. type icmpReasonHopLimitExceeded struct{} func (*icmpReasonHopLimitExceeded) isICMPReason() {} +func (*icmpReasonHopLimitExceeded) isForwarding() bool { + // If we hit a Hop Limit Exceeded error, then we know we are operating + // as a router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard + // the packet and originate an ICMPv6 Time Exceeded message with Code + // 0 to the source of the packet. This indicates either a routing + // loop or too small an initial Hop Limit value. + return true +} + +func (*icmpReasonHopLimitExceeded) respondsToMulticast() bool { + return false +} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -1002,6 +1078,14 @@ type icmpReasonReassemblyTimeout struct{} func (*icmpReasonReassemblyTimeout) isICMPReason() {} +func (*icmpReasonReassemblyTimeout) isForwarding() bool { + return false +} + +func (*icmpReasonReassemblyTimeout) respondsToMulticast() bool { + return false +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { @@ -1030,25 +1114,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // Section 4.2 of [IPv6]) that has the Option Type highest- // order two bits set to 10). // - var allowResponseToMulticast bool - if reason, ok := reason.(*icmpReasonParameterProblem); ok { - allowResponseToMulticast = reason.respondToMulticast - } - + allowResponseToMulticast := reason.respondsToMulticast() isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } - // If we hit a Hop Limit Exceeded error, then we know we are operating as a - // router. As per RFC 4443 section 3.3: - // - // If a router receives a packet with a Hop Limit of zero, or if a - // router decrements a packet's Hop Limit to zero, it MUST discard the - // packet and originate an ICMPv6 Time Exceeded message with Code 0 to - // the source of the packet. This indicates either a routing loop or - // too small an initial Hop Limit value. - // // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. @@ -1058,7 +1129,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // packet as "multicast addresses must not be used as source addresses in IPv6 // packets", as per RFC 4291 section 2.7. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + if reason.isForwarding() || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -1147,6 +1218,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) + counter = sent.dstUnreachable + case *icmpReasonPacketTooBig: + icmpHdr.SetType(header.ICMPv6PacketTooBig) + icmpHdr.SetCode(header.ICMPv6UnusedCode) + counter = sent.packetTooBig case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..040cd4bc8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f7510c243..95e11ac51 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -63,6 +63,11 @@ const ( buckets = 2048 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + // policyTable is the default policy table defined in RFC 6724 section 2.1. // // A more human-readable version: @@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 { var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -187,6 +193,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -405,27 +417,39 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t } } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.setForwarding(forwarding) { + return + } + allRoutersGroups := [...]tcpip.Address{ header.IPv6AllRoutersInterfaceLocalMulticastAddress, header.IPv6AllRoutersLinkLocalMulticastAddress, header.IPv6AllRoutersSiteLocalMulticastAddress, } - e.mu.Lock() - defer e.mu.Unlock() - if forwarding { - // When transitioning into an IPv6 router, host-only state (NDP discovered - // routers, discovered on-link prefixes, and auto-generated addresses) is - // cleaned up/invalidated and NDP router solicitations are stopped. - e.mu.ndp.stopSolicitingRouters() - e.mu.ndp.cleanupState(true /* hostOnly */) - // As per RFC 4291 section 2.8: // // A router is required to recognize all addresses that a host is @@ -449,28 +473,19 @@ func (e *endpoint) transitionForwarding(forwarding bool) { panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) } } - - return - } - - for _, g := range allRoutersGroups { - switch err := e.leaveGroupLocked(g).(type) { - case nil: - case *tcpip.ErrBadLocalAddress: - // The endpoint may have already left the multicast group. - default: - panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } else { + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } } } - // When transitioning into an IPv6 host, NDP router solicitations are - // started if the endpoint is enabled. - // - // If the endpoint is not currently enabled, routers will be solicited when - // the endpoint becomes enabled (if it is still a host). - if e.Enabled() { - e.mu.ndp.startSolicitingRouters() - } + e.mu.ndp.forwardingChanged(forwarding) } // Enable implements stack.NetworkEndpoint. @@ -552,17 +567,7 @@ func (e *endpoint) Enable() tcpip.Error { e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyway. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !e.protocol.Forwarding() { - e.mu.ndp.startSolicitingRouters() - } - + e.mu.ndp.startSolicitingRouters() return nil } @@ -613,7 +618,7 @@ func (e *endpoint) disableLocked() { return true }) - e.mu.ndp.cleanupState(false /* hostOnly */) + e.mu.ndp.cleanupState() // The endpoint may have already left the multicast group. switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { @@ -786,6 +791,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol } if packetMustBeFragmented(pkt, networkMTU) { + if pkt.NetworkPacketInfo.IsForwardedPacket { + // As per RFC 2460, section 4.5: + // Unlike IPv4, fragmentation in IPv6 is performed only by source nodes, + // not by routers along a packet's delivery path. + return &tcpip.ErrMessageTooLong{} + } sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we @@ -928,16 +939,19 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv6(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() - if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { - // As per RFC 4291 section 2.5.6, - // - // Routers must not forward any packets with Link-Local source or - // destination addresses to other links. - return nil + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} } hopLimit := h.HopLimit() @@ -949,21 +963,56 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // packet and originate an ICMPv6 Time Exceeded message with Code 0 to // the source of the packet. This indicates either a routing loop or // too small an initial Hop Limit value. - return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + // Check extension headers for any errors requiring action during forwarding. + if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil { + return &ip.ErrParameterProblem{} + } + + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning the + // ICMP packet because the original error is more relevant to the caller. + _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -975,10 +1024,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + IsForwardedPacket: true, + })); err.(type) { + case nil: + return nil + case *tcpip.ErrMessageTooLong: + // As per RFC 4443, section 3.2: + // A Packet Too Big MUST be sent by a router in response to a packet that + // it cannot forward because the packet is larger than the MTU of the + // outgoing link. + _ = e.protocol.returnError(&icmpReasonPacketTooBig{}, pkt) + return &ip.ErrMessageTooLong{} + default: + return &ip.ErrOther{Err: err} + } } // HandlePacket is called by the link layer when new ipv6 packets arrive for @@ -1059,6 +1121,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip + stats.ValidPacketsReceived.Increment() + srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1075,15 +1139,54 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.InvalidDestinationAddressesReceived.Increment() return } + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + e.stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + e.stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + e.stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + e.stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment() + case *ip.ErrMessageTooLong: + e.stats.ip.Forwarding.PacketTooBig.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + e.stats.ip.Forwarding.Errors.Increment() + return + } - _ = e.forwardPacket(pkt) + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + // iptables is telling us to drop the packet. + stats.IPTablesInputDropped.Increment() return } + // Any returned error is only useful for terminating execution early, but + // we have nothing left to do, so we can drop it. + _ = e.processExtensionHeaders(h, pkt, false /* forwarding */) +} + +// processExtensionHeaders processes the extension headers in the given packet. +// Returns an error if the processing of a header failed or if the packet should +// be discarded. +func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error { + stats := e.stats.ip + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // Create a VV to parse the packet. We don't plan to modify anything here. // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). @@ -1094,15 +1197,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) vv.AppendViews(pkt.Data().Views()) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) - // iptables filtering. All packets that reach here are intended for - // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { - // iptables is telling us to drop the packet. - stats.IPTablesInputDropped.Increment() - return - } - var ( hasFragmentHeader bool routerAlert *header.IPv6RouterAlertOption @@ -1115,22 +1209,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) extHdr, done, err := it.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break } + // As per RFC 8200, section 4: + // + // Extension headers (except for the Hop-by-Hop Options header) are + // not processed, inserted, or deleted by any node along a packet's + // delivery path until the packet reaches the node identified in the + // Destination Address field of the IPv6 header. + // + // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension + // header is restricted to appear first in the list of extension headers. + // + // Therefore, we can immediately return once we hit any header other + // than the Hop-by-Hop header while forwarding a packet. + if forwarding { + if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok { + return nil + } + } + switch extHdr := extHdr.(type) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: previousHeaderStart, + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart) } optsIt := extHdr.Iter() @@ -1139,7 +1252,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1154,7 +1267,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // There MUST only be one option of this type, regardless of // value, per Hop-by-Hop header. stats.MalformedPacketsReceived.Increment() - return + return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert) } routerAlert = opt stats.OptionRouterAlertReceived.Increment() @@ -1162,10 +1275,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1180,10 +1293,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt)) } } } @@ -1205,8 +1319,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), + // For the sake of consistency, we're using the value of `forwarding` + // here, even though it should always be false if we've reached this + // point. If `forwarding` is true here, we're executing undefined + // behavior no matter what. + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr) } case header.IPv6FragmentExtHdr: @@ -1241,7 +1360,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if done { break @@ -1269,7 +1388,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) default: stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr) } } @@ -1278,7 +1397,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Drop the packet as it's marked as a fragment but has no payload. stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("fragment has no payload") } // As per RFC 2460 Section 4.5: @@ -1296,7 +1415,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, }, pkt) - return + return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen) } // The packet is a fragment, let's try to reassemble it. @@ -1310,14 +1429,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Parameter Problem, Code 0, message should be sent to the source of // the fragment, pointing to the Fragment Offset field of the fragment // packet. - if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { + lengthAfterReassembly := int(start) + fragmentPayloadLen + if lengthAfterReassembly > header.IPv6MaximumPayloadSize { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, }, pkt) - return + return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize) } // Note that pkt doesn't have its transport header set after reassembly, @@ -1339,7 +1459,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if ready { @@ -1361,7 +1481,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1372,10 +1492,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown destination header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1392,9 +1512,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, }, pkt) - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt)) } } @@ -1402,13 +1522,19 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + // Calculate the number of octets parsed from data. We want to remove all + // the data except the unparsed portion located at the end, which its size + // is extHdr.Buf.Size(). + trim := pkt.Data().Size() - extHdr.Buf.Size() + // For unfragmented packets, extHdr still contains the transport header. // Get rid of it. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data().Replace(extHdr.Buf) + trim += pkt.TransportHeader().View().Size() + + pkt.Data().DeleteFront(trim) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -1425,6 +1551,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) + return fmt.Errorf("destination port unreachable") case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -1456,6 +1583,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownHeader, pointer: prevHdrIDOffset, }, pkt) + return fmt.Errorf("transport protocol unreachable") default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -1469,6 +1597,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) } } + return nil } // Close cleans up resources associated with the endpoint. @@ -1490,8 +1619,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1532,8 +1661,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -1610,8 +1739,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) } @@ -1833,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1858,12 +1986,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - fragmentation *fragmentation.Fragmentation } @@ -2038,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 40a793d6b..afc6c3547 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -31,8 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -3003,52 +3004,289 @@ func TestFragmentationErrors(t *testing.T) { func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 ) - ipv6Addr1 := tcpip.AddressWithPrefix{ + incomingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10::1").To16()), PrefixLen: 64, } - ipv6Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11::1").To16()), PrefixLen: 64, } + multicastIPv6Addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("ff00::").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) + linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool + name string + extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) + TTL uint8 + expectErrorICMP bool + expectPacketForwarded bool + payloadLength int + countUnrouteablePackets uint64 + sourceAddr tcpip.Address + destAddr tcpip.Address + icmpType header.ICMPv6Type + icmpCode header.ICMPv6Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool + expectExtensionHeaderError bool }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { name: "TTL of one", TTL: 1, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "TTL of three", + TTL: 3, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "Network unreachable", + TTL: 2, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: unreachableIPv6Addr, + icmpType: header.ICMPv6DstUnreachable, + icmpCode: header.ICMPv6NetworkUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + countUnrouteablePackets: 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + expectPacketForwarded: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: linkLocalIPv6Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv6Addr, + destAddr: remoteIPv6Addr2, + expectLinkLocalSourceError: true, + }, + { + name: "Hopbyhop with unknown option skippable action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption())) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with unknown option discard action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, }, { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, + name: "Hopbyhop with router alert option", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD))) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with two router alert options", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Can't fragment", + TTL: 2, + payloadLength: header.IPv6MinimumMTU + 1, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6PacketTooBig, + icmpCode: header.ICMPv6UnusedCode, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Can't fragment multicast", + TTL: 2, + payloadLength: header.IPv6MinimumMTU + 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + expectErrorICMP: true, + icmpType: header.ICMPv6PacketTooBig, + icmpCode: header.ICMPv6UnusedCode, }, } @@ -3059,41 +3297,60 @@ func TestForwarding(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} - if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) } - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} - if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv6Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv6Addr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: outgoingIPv6Addr.Subnet(), + NIC: outgoingNICID, }, { - Destination: ipv6Addr2.Subnet(), - NIC: nicID2, + Destination: multicastIPv6Addr.Subnet(), + NIC: outgoingNICID, }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) - icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + transportProtocol := header.ICMPv6ProtocolNumber + extHdrBytes := []byte{} + extHdrChecker := checker.IPv6ExtHdr() + if test.extHdr != nil { + nextHdrID := hopByHopExtHdrID + extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber)) + transportProtocol = tcpip.TransportProtocolNumber(nextHdrID) + } + extHdrLen := len(extHdrBytes) + + ipHeaderLength := header.IPv6MinimumSize + icmpHeaderLength := header.ICMPv6MinimumSize + totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen + hdr := buffer.NewPrependable(totalLength) + hdr.Prepend(test.payloadLength) + icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) + icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv6EchoRequest) @@ -3101,52 +3358,72 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmp, - Src: remoteIPv6Addr1, - Dst: remoteIPv6Addr2, + Src: test.sourceAddr, + Dst: test.destAddr, })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + copy(hdr.Prepend(extHdrLen), extHdrBytes) + ip := header.IPv6(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: header.ICMPv6ProtocolNumber, + PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), + TransportProtocol: transportProtocol, HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(ProtocolNumber, requestPkt) + incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) + } + + // As per RFC 4443, page 9: + // + // The returned ICMP packet will contain as much of invoking packet + // as possible without the ICMPv6 packet exceeding the minimum IPv6 + // MTU. + expectedICMPPayloadLength := func() int { + maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength + if len(hdr.View()) > maxICMPPayloadLength { + return maxICMPPayloadLength + } + return len(hdr.View()) } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), + checker.SrcAddr(incomingIPv6Addr.Address), + checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), - checker.ICMPv6Payload([]byte(hdr.View())), + checker.ICMPv6Type(test.icmpType), + checker.ICMPv6Code(test.icmpCode), + checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), ), ) - if n := e2.Drain(); n != 0 { + if n := outgoingEndpoint.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = outgoingEndpoint.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), + checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), + extHdrChecker, checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6EchoRequest), checker.ICMPv6Code(header.ICMPv6UnusedCode), @@ -3154,9 +3431,46 @@ func TestForwarding(t *testing.T) { ), ) - if n := e1.Drain(); n != 0 { + if n := incomingEndpoint.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want { + t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want { + t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) } }) } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d6e0a81a6..f0ff111c5 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -48,7 +48,7 @@ const ( // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. - defaultHandleRAs = true + defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router @@ -301,10 +301,60 @@ type NDPDispatcher interface { OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } +var _ fmt.Stringer = HandleRAsConfiguration(0) + +// HandleRAsConfiguration enumerates when RAs may be handled. +type HandleRAsConfiguration int + +const ( + // HandlingRAsDisabled indicates that Router Advertisements will not be + // handled. + HandlingRAsDisabled HandleRAsConfiguration = iota + + // HandlingRAsEnabledWhenForwardingDisabled indicates that router + // advertisements will only be handled when forwarding is disabled. + HandlingRAsEnabledWhenForwardingDisabled + + // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always + // be handled, even when forwarding is enabled. + HandlingRAsAlwaysEnabled +) + +// String implements fmt.Stringer. +func (c HandleRAsConfiguration) String() string { + switch c { + case HandlingRAsDisabled: + return "HandlingRAsDisabled" + case HandlingRAsEnabledWhenForwardingDisabled: + return "HandlingRAsEnabledWhenForwardingDisabled" + case HandlingRAsAlwaysEnabled: + return "HandlingRAsAlwaysEnabled" + default: + return fmt.Sprintf("HandleRAsConfiguration(%d)", c) + } +} + +// enabled returns true iff Router Advertisements may be handled given the +// specified forwarding status. +func (c HandleRAsConfiguration) enabled(forwarding bool) bool { + switch c { + case HandlingRAsDisabled: + return false + case HandlingRAsEnabledWhenForwardingDisabled: + return !forwarding + case HandlingRAsAlwaysEnabled: + return true + default: + panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c)) + } +} + // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. + // + // Ignored unless configured to handle Router Advertisements. MaxRtrSolicitations uint8 // The amount of time between transmitting Router Solicitation messages. @@ -318,8 +368,9 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements are processed. - HandleRAs bool + // HandleRAs is the configuration for when Router Advertisements should be + // handled. + HandleRAs HandleRAsConfiguration // DiscoverDefaultRouters determines whether or not default routers are // discovered from Router Advertisements, as per RFC 4861 section 6. This @@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { + ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t delete(tempAddrs, tempAddr) } -// removeSLAACAddresses removes all SLAAC addresses. -// -// If keepLinkLocal is false, the SLAAC generated link-local address is removed. -// -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { - linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - var linkLocalPrefixes int - for prefix, state := range ndp.slaacPrefixes { - // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them if we are cleaning up - // host-only state. - if keepLinkLocal && prefix == linkLocalSubnet { - linkLocalPrefixes++ - continue - } - - ndp.invalidateSLAACPrefix(prefix, state) - } - - if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) - } -} - // cleanupState cleans up ndp's state. // -// If hostOnly is true, then only host-specific state is cleaned up. -// // This function invalidates all discovered on-link prefixes, discovered // routers, and auto-generated addresses. // -// If hostOnly is true, then the link-local auto-generated address aren't -// invalidated as routers are also expected to generate a link-local address. -// // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { - ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) +func (ndp *ndpState) cleanupState() { + for prefix, state := range ndp.slaacPrefixes { + ndp.invalidateSLAACPrefix(prefix, state) + } for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // +// If ndp is not configured to handle Router Advertisements, routers will not +// be solicited as there is no point soliciting routers if we don't handle their +// advertisements. +// // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicitTimer.timer != nil { @@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() { return } + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { + return + } + // Calculate the random delay before sending our first RS, as per RFC // 4861 section 6.3.7. var delay time.Duration @@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() { } } +// forwardingChanged handles a change in forwarding configuration. +// +// If transitioning to a host, router solicitation will be started. Otherwise, +// router solicitation will be stopped if NDP is not configured to handle RAs +// as a router. +// +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) forwardingChanged(forwarding bool) { + if forwarding { + if ndp.configs.HandleRAs.enabled(forwarding) { + return + } + + ndp.stopSolicitingRouters() + return + } + + // Solicit routers when transitioning to a host. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if ndp.ep.Enabled() { + ndp.startSolicitingRouters() + } +} + // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..234e34952 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { } func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } + const nicID = 1 handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { var extHdrs header.IPv6ExtHdrSerializer @@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) { }, } + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } + for _, typ := range types { for _, isRouter := range []bool{false, true} { name := typ.name @@ -875,13 +872,35 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatal("cannot find network endpoint instance for IPv6") + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}) + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost @@ -906,12 +925,12 @@ func TestNDPValidation(t *testing.T) { // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + t.Errorf("got invalid.Value() = %d, want = 0", got) } - // RouterOnlyPacketsReceivedByHost count should initially be 0. + // Should initially not have dropped any packets. if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + t.Errorf("got routerOnly.Value() = %d, want = 0", got) } if t.Failed() { @@ -931,18 +950,18 @@ func TestNDPValidation(t *testing.T) { want = 1 } if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + t.Errorf("got invalid.Value() = %d, want = %d", got, want) } want = 0 if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. + // Router only packets are expected to be dropped when operating + // as a host. want = 1 } if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + t.Errorf("got routerOnly.Value() = %d, want = %d", got, want) } - }) } }) diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index c2758352f..2f18f60e8 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -29,6 +29,10 @@ type Stats struct { // ICMP holds ICMPv6 statistics. ICMP tcpip.ICMPv6Stats + + // UnhandledRouterAdvertisements is the number of Router Advertisements that + // were observed but not handled. + UnhandledRouterAdvertisements *tcpip.StatCounter } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index a6c877158..b26936b7f 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -18,6 +18,7 @@ import ( "math" "sync/atomic" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sync" ) @@ -213,7 +214,7 @@ type SocketOptions struct { getSendBufferLimits GetSendBufferLimits `state:"manual"` // sendBufferSize determines the send buffer size for this socket. - sendBufferSize int64 + sendBufferSize atomicbitops.AlignedAtomicInt64 // getReceiveBufferLimits provides the handler to get the min, default and // max size for receive buffer. It is initialized at the creation time and @@ -612,7 +613,7 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { // GetSendBufferSize gets value for SO_SNDBUF option. func (so *SocketOptions) GetSendBufferSize() int64 { - return atomic.LoadInt64(&so.sendBufferSize) + return so.sendBufferSize.Load() } // SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the @@ -621,7 +622,7 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { v := sendBufferSize if !notify { - atomic.StoreInt64(&so.sendBufferSize, v) + so.sendBufferSize.Store(v) return } @@ -647,7 +648,7 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { // Notify endpoint about change in buffer size. newSz := so.handler.OnSetSendBufferSize(v) - atomic.StoreInt64(&so.sendBufferSize, newSz) + so.sendBufferSize.Store(newSz) } // GetReceiveBufferSize gets value for SO_RCVBUF option. diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2bd6a67f5..84aa6a9e4 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -73,6 +73,8 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/atomicbitops", + "//pkg/buffer", "//pkg/ilist", "//pkg/log", "//pkg/rand", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index e5590ecc0..ce9cebdaa 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -440,33 +440,54 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad // Regardless how the address was obtained, it will be acquired before it is // returned. func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { - a.mu.Lock() - defer a.mu.Unlock() + lookup := func() *addressState { + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } - if addrState, ok := a.mu.endpoints[localAddr]; ok { - if !addrState.IsAssigned(allowTemp) { - return nil - } + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } - if !addrState.IncRef() { - panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + return addrState } - return addrState - } - - if f != nil { - for _, addrState := range a.mu.endpoints { - if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { - return addrState + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } } } + return nil + } + // Avoid exclusive lock on mu unless we need to add a new address. + a.mu.RLock() + ep := lookup() + a.mu.RUnlock() + + if ep != nil { + return ep } if !allowTemp { return nil } + // Acquire state lock in exclusive mode as we need to add a new temporary + // endpoint. + a.mu.Lock() + defer a.mu.Unlock() + + // Do the lookup again in case another goroutine added the address in the time + // we released and acquired the lock. + ep = lookup() + if ep != nil { + return ep + } + + // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) if err != nil { @@ -475,6 +496,7 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // expect no error. panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) } + // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 2d74e0abc..7107d598d 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct { nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher + + mu struct { + sync.RWMutex + forwarding bool + } } func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { @@ -101,7 +106,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: vv.ToView().ToVectorisedView(), }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. + // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. _ = r.WriteHeaderIncludedPacket(pkt) } @@ -169,11 +174,6 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -242,16 +242,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -264,6 +264,8 @@ type fwdTestPacketInfo struct { Pkt *PacketBuffer } +var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) + type fwdTestLinkEndpoint struct { dispatcher NetworkDispatcher mtu uint32 @@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { return caps | CapabilityResolutionRequired } -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { @@ -370,8 +367,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index e2894c548..3670d5995 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -177,6 +177,7 @@ func DefaultTables() *IPTables { priorities: [NumHooks][]TableID{ Prerouting: {MangleID, NATID}, Input: {NATID, FilterID}, + Forward: {FilterID}, Output: {MangleID, NATID, FilterID}, Postrouting: {MangleID, NATID}, }, diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4631ab93f..93592e7f5 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) case Output: return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) - case Forward, Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. + case Forward: + if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) { + return false + } + + if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) { + return false + } + + return true + case Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b6cf24739..ac2fa777e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -481,13 +481,9 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } e := channelLinkWithHeaderLength{ @@ -499,7 +495,9 @@ func TestDADResolve(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ + Clock: clock, SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, @@ -529,14 +527,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - // Make sure the address does not resolve before the resolution time has // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) + const delta = time.Nanosecond + clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Error(err) } @@ -566,13 +560,14 @@ func TestDADResolve(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID) } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { t.Error(err) @@ -1146,57 +1141,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) +func TestDynamicConfigurationsDisabled(t *testing.T) { + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + prefix := tcpip.AddressWithPrefix{ + Address: testutil.MustParse6("102:304:506:708::"), + PrefixLen: 64, + } - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: + tests := []struct { + name string + config func(bool) ipv6.NDPConfigurations + ra *stack.PacketBuffer + }{ + { + name: "No Router Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} + }, + ra: raBuf(llAddr2, 1000), + }, + { + name: "No Prefix Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), + }, + { + name: "No Autogenerate Addresses", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Being configured to discover routers/prefixes or auto-generate + // addresses means RAs must be handled, and router/prefix discovery or + // SLAAC must be enabled. + // + // This tests all possible combinations of the configurations where + // router/prefix discovery or SLAAC are disabled. + for i := 0; i < 7; i++ { + handle := ipv6.HandlingRAsDisabled + if i&1 != 0 { + handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled + } + enable := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + ndpConfigs := test.config(enable) + ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, + }) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding + ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) + } + stats := ep.Stats() + v6Stats, ok := stats.(*ipv6.Stats) + if !ok { + t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) + } + + // Make sure that when handling RAs are enabled, we solicit routers. + clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) + } + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) + } + + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. + if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) + default: + } + }) } }) } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) } +func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { + tests := [...]struct { + name string + handleRAs ipv6.HandleRAsConfiguration + forwarding bool + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding disabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding enabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f(t, test.handleRAs, test.forwarding) + }) + } +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1207,7 +1343,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1241,103 +1377,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") } - default: - t.Fatal("expected router discovery event") } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") } - } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + // Rx an RA from lladdr2 with lesser lifetime. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + }) } // TestRouterDiscoveryMaxRouters tests that only @@ -1351,7 +1493,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1390,57 +1532,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for prefix on nic with ID 1, and the // discovered flag set to discovered. func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { @@ -1459,8 +1550,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1498,87 +1588,93 @@ func TestPrefixDiscovery(t *testing.T) { prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") } - default: - t.Fatal("expected prefix discovery event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) + }) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1607,7 +1703,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1692,7 +1788,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: false, DiscoverOnLinkPrefixes: true, }, @@ -1757,53 +1853,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) return containsAddr(list, protocolAddress) } -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for addr on nic with ID 1, and the // event type is set to eventType. func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { @@ -1812,7 +1861,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr2(t *testing.T) { +func TestAutoGenAddr(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second saved := ipv6.MinPrefixInformationValidLifetimeForUpdate @@ -1824,96 +1873,102 @@ func TestAutoGenAddr2(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - default: - t.Fatal("expected addr auto gen event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + }) } func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { @@ -2001,7 +2056,7 @@ func TestAutoGenTempAddr(t *testing.T) { RetransmitTimer: test.retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2302,7 +2357,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2389,7 +2444,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2538,7 +2593,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2739,7 +2794,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, @@ -2884,7 +2939,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: ndpDisp, @@ -3351,7 +3406,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3494,7 +3549,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3561,7 +3616,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3727,7 +3782,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3809,7 +3864,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3973,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { @@ -4000,7 +4055,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4150,7 +4205,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4278,7 +4333,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4484,7 +4539,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4535,7 +4590,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4629,8 +4684,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } } -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// auto-generated addresses are invalidated when a NIC becomes a router. +func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { + const ( + lifetimeSeconds = 999 + nicID = 1 + ) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenLinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) + + e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) + if err := s.CreateNIC(nicID, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) + } + + prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) + e1.InjectInbound( + header.IPv6ProtocolNumber, + raBufWithPI( + llAddr3, + lifetimeSeconds, + prefix, + true, /* onLink */ + true, /* auto */ + lifetimeSeconds, + lifetimeSeconds, + ), + ) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + } + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) + } + + // Enabling or disabling forwarding should not invalidate discovered prefixes + // or routers, or auto-generated address. + for _, forwarding := range [...]bool{true, false} { + t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpected router event = %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpected prefix event = %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpected auto-gen addr event = %#v", e) + default: + } + }) + } +} + func TestCleanupNDPState(t *testing.T) { const ( lifetimeSeconds = 5 @@ -4659,18 +4816,6 @@ func TestCleanupNDPState(t *testing.T) { maxAutoGenAddrEvents int skipFinalAddrCheck bool }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. - { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - // A NIC should cleanup all NDP state when it is disabled. { name: "Disable NIC", @@ -4722,7 +4867,7 @@ func TestCleanupNDPState(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true, @@ -4995,7 +5140,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -5186,96 +5331,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. + { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5300,11 +5476,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, @@ -5373,6 +5555,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 48bb75e2f..9821a18d3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} + clock := tcpip.NewStdClock() linkRes := newTestNeighborResolver(nil, config, clock) linkRes.delay = 0 diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8d615500f..dbba2c79f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1000,3 +1000,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t return d.CheckDuplicateAddress(addr, h), nil } + +func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return &tcpip.ErrNotSupported{} + } + + forwardingEP.SetForwarding(enable) + return nil +} + +func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return false, &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return false, &tcpip.ErrNotSupported{} + } + + return forwardingEP.Forwarding(), nil +} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 646979d1e..4ca702121 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -16,9 +16,10 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + tcpipbuffer "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -39,7 +40,11 @@ type PacketBufferOptions struct { // Data is the initial unparsed data for the new packet. If set, it will be // owned by the new packet. - Data buffer.VectorisedView + Data tcpipbuffer.VectorisedView + + // IsForwardedPacket identifies that the PacketBuffer being created is for a + // forwarded packet. + IsForwardedPacket bool } // A PacketBuffer contains all the data of a network packet. @@ -52,6 +57,34 @@ type PacketBufferOptions struct { // empty. Use of PacketBuffer in any other order is unsupported. // // PacketBuffer must be created with NewPacketBuffer. +// +// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which +// exposes a logically-contiguous byte storage. The underlying storage structure +// is abstracted out, and should not be a concern here for most of the time. +// +// |- reserved ->| +// |--->| consumed (incoming) +// 0 V V +// +--------+----+----+--------------------+ +// | | | | current data ... | (buf) +// +--------+----+----+--------------------+ +// ^ | +// |<---| pushed (outgoing) +// +// When a PacketBuffer is created, a `reserved` header region can be specified, +// which stack pushes headers in this region for an outgoing packet. There could +// be no such region for an incoming packet, and `reserved` is 0. The value of +// `reserved` never changes in the entire lifetime of the packet. +// +// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the +// pushed length, and the current value is stored for each header. PacketBuffer +// substracts this value from `reserved` to compute the starting offset of each +// header in `buf`. +// +// Incoming Packet: When a header is consumed (a.k.a. parsed), the current +// `consumed` value is stored for each header, and it gets incremented by the +// consumed length. PacketBuffer adds this value to `reserved` to compute the +// starting offset of each header in `buf`. type PacketBuffer struct { _ sync.NoCopy @@ -59,28 +92,16 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // data holds the payload of the packet. - // - // For inbound packets, Data is initially the whole packet. Then gets moved to - // headers via PacketHeader.Consume, when the packet is being parsed. - // - // For outbound packets, Data is the innermost layer, defined by the protocol. - // Headers are pushed in front of it via PacketHeader.Push. - // - // The bytes backing Data are immutable, a.k.a. users shouldn't write to its - // backing storage. - data buffer.VectorisedView + // buf is the underlying buffer for the packet. See struct level docs for + // details. + buf *buffer.Buffer + reserved int + pushed int + consumed int // headers stores metadata about each header. headers [numHeaderType]headerInfo - // header is the internal storage for outbound packets. Headers will be pushed - // (prepended) on this storage as the packet is being constructed. - // - // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and - // data are held in the same underlying buffer storage. - header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -127,10 +148,17 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - data: opts.Data, + buf: &buffer.Buffer{}, } if opts.ReserveHeaderBytes != 0 { - pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + pk.buf.AppendOwned(make([]byte, opts.ReserveHeaderBytes)) + pk.reserved = opts.ReserveHeaderBytes + } + for _, v := range opts.Data.Views() { + pk.buf.AppendOwned(v) + } + if opts.IsForwardedPacket { + pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket } return pk } @@ -138,13 +166,13 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { - return pk.header.UsedLength() + pk.header.AvailableLength() + return pk.reserved } // AvailableHeaderBytes returns the number of bytes currently available for // headers. This is relevant to PacketHeader.Push method only. func (pk *PacketBuffer) AvailableHeaderBytes() int { - return pk.header.AvailableLength() + return pk.reserved - pk.pushed } // LinkHeader returns the handle to link-layer header. @@ -173,24 +201,18 @@ func (pk *PacketBuffer) TransportHeader() PacketHeader { // HeaderSize returns the total size of all headers in bytes. func (pk *PacketBuffer) HeaderSize() int { - // Note for inbound packets (Consume called), headers are not stored in - // pk.header. Thus, calculation of size of each header is needed. - var size int - for i := range pk.headers { - size += len(pk.headers[i].buf) - } - return size + return pk.pushed + pk.consumed } // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.data.Size() + return int(pk.buf.Size()) - pk.headerOffset() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize + return int(pk.buf.Size()) + packetBufferStructSize } // Data returns the handle to data portion of pk. @@ -199,61 +221,65 @@ func (pk *PacketBuffer) Data() PacketData { } // Views returns the underlying storage of the whole packet. -func (pk *PacketBuffer) Views() []buffer.View { - // Optimization for outbound packets that headers are in pk.header. - useHeader := true - for i := range pk.headers { - if !canUseHeader(&pk.headers[i]) { - useHeader = false - break - } - } +func (pk *PacketBuffer) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := pk.headerOffset() + pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views +} - dataViews := pk.data.Views() - - var vs []buffer.View - if useHeader { - vs = make([]buffer.View, 0, 1+len(dataViews)) - vs = append(vs, pk.header.View()) - } else { - vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) - for i := range pk.headers { - if v := pk.headers[i].buf; len(v) > 0 { - vs = append(vs, v) - } - } - } - return append(vs, dataViews...) +func (pk *PacketBuffer) headerOffset() int { + return pk.reserved - pk.pushed +} + +func (pk *PacketBuffer) headerOffsetOf(typ headerType) int { + return pk.reserved + pk.headers[typ].offset } -func canUseHeader(h *headerInfo) bool { - // h.offset will be negative if the header was pushed in to prependable - // portion, or doesn't matter when it's empty. - return len(h.buf) == 0 || h.offset < 0 +func (pk *PacketBuffer) dataOffset() int { + return pk.reserved + pk.consumed } -func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { +func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] - if h.buf != nil { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + if h.length > 0 { + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) + } + if pk.pushed+size > pk.reserved { + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } - h.buf = buffer.View(pk.header.Prepend(size)) - h.offset = -pk.header.UsedLength() - return h.buf + pk.pushed += size + h.offset = -pk.pushed + h.length = size + return pk.headerView(typ) } -func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { +func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, consumed bool) { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.data.PullUp(size) + if pk.reserved+pk.consumed+size > int(pk.buf.Size()) { + return nil, false + } + h.offset = pk.consumed + h.length = size + pk.consumed += size + return pk.headerView(typ), true +} + +func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { + h := &pk.headers[typ] + if h.length == 0 { + return nil + } + v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length) if !ok { - return + panic("PullUp failed") } - pk.data.TrimFront(size) - h.buf = v - return h.buf, true + return v } // Clone makes a shallow copy of pk. @@ -263,9 +289,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - data: pk.data.Clone(nil), + buf: pk.buf, + reserved: pk.reserved, + pushed: pk.pushed, + consumed: pk.consumed, headers: pk.headers, - header: pk.header, Hash: pk.Hash, Owner: pk.Owner, GSOOptions: pk.GSOOptions, @@ -299,9 +327,11 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - newPk := NewPacketBuffer(PacketBufferOptions{ - Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), - }) + newPk := &PacketBuffer{ + buf: pk.buf, + // Treat unfilled header portion as reserved. + reserved: pk.AvailableHeaderBytes(), + } // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to // maintain this flag in the packet. Currently conntrack needs this flag to // tell if a noop connection should be inserted at Input hook. Once conntrack @@ -315,15 +345,12 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // headerInfo stores metadata about a header in a packet. type headerInfo struct { - // buf is the memorized slice for both prepended and consumed header. - // When header is prepended, buf serves as memorized value, which is a slice - // of pk.header. When header is consumed, buf is the slice pulled out from - // pk.Data, which is the only place to hold this header. - buf buffer.View - - // offset will be a negative number denoting the offset where this header is - // from the end of pk.header, if it is prepended. Otherwise, zero. + // offset is the offset of the header in pk.buf relative to + // pk.buf[pk.reserved]. See the PacketBuffer struct for details. offset int + + // length is the length of this header. + length int } // PacketHeader is a handle object to a header in the underlying packet. @@ -333,14 +360,14 @@ type PacketHeader struct { } // View returns the underlying storage of h. -func (h PacketHeader) View() buffer.View { - return h.pk.headers[h.typ].buf +func (h PacketHeader) View() tcpipbuffer.View { + return h.pk.headerView(h.typ) } // Push pushes size bytes in the front of its residing packet, and returns the // backing storage. Callers may only call one of Push or Consume once on each // header in the lifetime of the underlying packet. -func (h PacketHeader) Push(size int) buffer.View { +func (h PacketHeader) Push(size int) tcpipbuffer.View { return h.pk.push(h.typ, size) } @@ -349,7 +376,7 @@ func (h PacketHeader) Push(size int) buffer.View { // size, consumed will be false, and the state of h will not be affected. // Callers may only call one of Push or Consume once on each header in the // lifetime of the underlying packet. -func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { +func (h PacketHeader) Consume(size int) (v tcpipbuffer.View, consumed bool) { return h.pk.consume(h.typ, size) } @@ -360,54 +387,84 @@ type PacketData struct { // PullUp returns a contiguous view of size bytes from the beginning of d. // Callers should not write to or keep the view for later use. -func (d PacketData) PullUp(size int) (buffer.View, bool) { - return d.pk.data.PullUp(size) +func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { + return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// TrimFront removes count from the beginning of d. It panics if count > -// d.Size(). -func (d PacketData) TrimFront(count int) { - d.pk.data.TrimFront(count) +// DeleteFront removes count from the beginning of d. It panics if count > +// d.Size(). All backing storage references after the front of the d are +// invalidated. +func (d PacketData) DeleteFront(count int) { + if !d.pk.buf.Remove(d.pk.dataOffset(), count) { + panic("count > d.Size()") + } } // CapLength reduces d to at most length bytes. func (d PacketData) CapLength(length int) { - d.pk.data.CapLength(length) + if length < 0 { + panic("length < 0") + } + if currLength := d.Size(); currLength > length { + trim := currLength - length + d.pk.buf.Remove(int(d.pk.buf.Size())-trim, trim) + } } // Views returns the underlying storage of d in a slice of Views. Caller should // not modify the returned slice. -func (d PacketData) Views() []buffer.View { - return d.pk.data.Views() +func (d PacketData) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := d.pk.dataOffset() + d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views } // AppendView appends v into d, taking the ownership of v. -func (d PacketData) AppendView(v buffer.View) { - d.pk.data.AppendView(v) +func (d PacketData) AppendView(v tcpipbuffer.View) { + d.pk.buf.AppendOwned(v) } -// ReadFromData moves at most count bytes from the beginning of srcData to the -// end of d and returns the number of bytes moved. -func (d PacketData) ReadFromData(srcData PacketData, count int) int { - return srcData.pk.data.ReadToVV(&d.pk.data, count) +// MergeFragment appends the data portion of frag to dst. It takes ownership of +// frag and frag should not be used again. +func MergeFragment(dst, frag *PacketBuffer) { + frag.buf.TrimFront(int64(frag.dataOffset())) + dst.buf.Merge(frag.buf) } // ReadFromVV moves at most count bytes from the beginning of srcVV to the end // of d and returns the number of bytes moved. -func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { - return srcVV.ReadToVV(&d.pk.data, count) +func (d PacketData) ReadFromVV(srcVV *tcpipbuffer.VectorisedView, count int) int { + done := 0 + for _, v := range srcVV.Views() { + if len(v) < count { + count -= len(v) + done += len(v) + d.pk.buf.AppendOwned(v) + } else { + v = v[:count] + count -= len(v) + done += len(v) + d.pk.buf.Append(v) + break + } + } + srcVV.TrimFront(done) + return done } // Size returns the number of bytes in the data payload of the packet. func (d PacketData) Size() int { - return d.pk.data.Size() + return int(d.pk.buf.Size()) - d.pk.dataOffset() } // AsRange returns a Range representing the current data payload of the packet. func (d PacketData) AsRange() Range { return Range{ pk: d.pk, - offset: d.pk.HeaderSize(), + offset: d.pk.dataOffset(), length: d.Size(), } } @@ -417,17 +474,12 @@ func (d PacketData) AsRange() Range { // // This method exists for compatibility between PacketBuffer and VectorisedView. // It may be removed later and should be used with care. -func (d PacketData) ExtractVV() buffer.VectorisedView { - return d.pk.data -} - -// Replace replaces the data portion of the packet with vv, taking the ownership -// of vv. -// -// This method exists for compatibility between PacketBuffer and VectorisedView. -// It may be removed later and should be used with care. -func (d PacketData) Replace(vv buffer.VectorisedView) { - d.pk.data = vv +func (d PacketData) ExtractVV() tcpipbuffer.VectorisedView { + var vv tcpipbuffer.VectorisedView + d.pk.buf.SubApply(d.pk.dataOffset(), d.pk.Size(), func(v []byte) { + vv.AppendView(v) + }) + return vv } // Range represents a contiguous subportion of a PacketBuffer. @@ -471,9 +523,9 @@ func (r Range) Capped(max int) Range { // AsView returns the backing storage of r if possible. It will allocate a new // View if r spans multiple pieces internally. Caller should not write to the // returned View in any way. -func (r Range) AsView() buffer.View { +func (r Range) AsView() tcpipbuffer.View { var allocated bool - var v buffer.View + var v tcpipbuffer.View r.iterate(func(b []byte) { if v == nil { // v has not been assigned, allowing first view to be returned. @@ -494,7 +546,7 @@ func (r Range) AsView() buffer.View { } // ToOwnedView returns a owned copy of data in r. -func (r Range) ToOwnedView() buffer.View { +func (r Range) ToOwnedView() tcpipbuffer.View { if r.length == 0 { return nil } @@ -515,63 +567,7 @@ func (r Range) Checksum() uint16 { // iterate calls fn for each piece in r. fn is always called with a non-empty // slice. func (r Range) iterate(fn func([]byte)) { - w := window{ - offset: r.offset, - length: r.length, - } - // Header portion. - for i := range r.pk.headers { - if b := w.process(r.pk.headers[i].buf); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - // Data portion. - if !w.isDone() { - for _, v := range r.pk.data.Views() { - if b := w.process(v); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - } -} - -// window represents contiguous region of byte stream. User would call process() -// to input bytes, and obtain a subslice that is inside the window. -type window struct { - offset int - length int -} - -// isDone returns true if the window has passed and further process() calls will -// always return an empty slice. This can be used to end processing early. -func (w *window) isDone() bool { - return w.length == 0 -} - -// process feeds b in and returns a subslice that is inside the window. The -// returned slice will be a subslice of b, and it does not keep b after method -// returns. This method may return an empty slice if nothing in b is inside the -// window. -func (w *window) process(b []byte) (inWindow []byte) { - if w.offset >= len(b) { - w.offset -= len(b) - return nil - } - if w.offset > 0 { - b = b[w.offset:] - w.offset = 0 - } - if w.length < len(b) { - b = b[:w.length] - } - w.length -= len(b) - return b + r.pk.buf.SubApply(r.offset, r.length, fn) } // PayloadSince returns packet payload starting from and including a particular @@ -579,21 +575,14 @@ func (w *window) process(b []byte) (inWindow []byte) { // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. -func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.data.Size() - for _, hinfo := range h.pk.headers[h.typ:] { - size += len(hinfo.buf) +func PayloadSince(h PacketHeader) tcpipbuffer.View { + offset := h.pk.headerOffset() + for i := headerType(0); i < h.typ; i++ { + offset += h.pk.headers[i].length } - - v := make(buffer.View, 0, size) - - for _, hinfo := range h.pk.headers[h.typ:] { - v = append(v, hinfo.buf...) - } - - for _, view := range h.pk.data.Views() { - v = append(v, view...) - } - - return v + return Range{ + pk: h.pk, + offset: offset, + length: int(h.pk.buf.Size()) - offset, + }.ToOwnedView() } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 6728370c3..a8da34992 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkData(t, pk, test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) + // Check the after state. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.link, + network: test.network, + transport: test.transport, + data: test.data, + }) }) } } @@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) { if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - // After state of pk. - var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkData(t, pk, payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) + // Check the after state of pk. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.data[:test.link], + network: test.data[test.link:][:test.network], + transport: test.data[test.link+test.network:][:test.transport], + data: test.data[allHdrSize:], + }) }) } } @@ -252,6 +226,70 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) { }) } +// This is a very obscure use-case seen in the code that verifies packets +// before sending them out. It tries to parse the headers to verify. +// PacketHeader was initially not designed to mix Push() and Consume(), but it +// works and it's been relied upon. Include a test here. +func TestPacketHeaderPushConsumeMixed(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := append([]byte(nil), network...) + initData = append(initData, data...) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Consume network header + gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) + if !ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) + } + checkViewEqual(t, "gotNetwork", gotNetwork, network) + + // 2. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + network: network, + data: data, + }) +} + +func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := concatViews(network, data) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) + + // 2. Consume network header, with a number of bytes too large. + gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) + if ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) + } + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 @@ -397,11 +435,11 @@ func TestPacketBufferData(t *testing.T) { } }) - // TrimFront + // DeleteFront for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().TrimFront(n) + pkt.Data().DeleteFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) @@ -437,23 +475,8 @@ func TestPacketBufferData(t *testing.T) { checkData(t, pkt, []byte(tc.data+s)) }) - // ReadFromData/VV + // ReadFromVV for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { - s := "TO READ" - otherPkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv(s, s), - }) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromData(otherPkt.Data(), n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { s := "TO READ" srcVV := vv(s, s) @@ -480,20 +503,41 @@ func TestPacketBufferData(t *testing.T) { t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) } }) - - // Replace - t.Run("Replace", func(t *testing.T) { - s := "REPLACED" - - pkt := tc.makePkt(t) - pkt.Data().Replace(vv(s)) - - checkData(t, pkt, []byte(s)) - }) }) } } +type packetContents struct { + link buffer.View + network buffer.View + transport buffer.View + data buffer.View +} + +func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { + t.Helper() + // Headers. + checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) + checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) + checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) + // Data. + checkData(t, pk, want.data) + // Whole packet. + checkViewEqual(t, prefix+"pk.Views()", + concatViews(pk.Views()...), + concatViews(want.link, want.network, want.transport, want.data)) + // PayloadSince. + checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(want.link, want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(want.transport, want.data)) +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -510,19 +554,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkData(t, pk, data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. - checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) + checkPacketContents(t, "Initial ", pk, packetContents{ + data: data, + }) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { @@ -540,7 +574,7 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { t.Helper() if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) } if got := pkt.Data().Size(); got != len(want) { t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 7ad206f6d..85bb87b4b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -55,6 +55,9 @@ type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool + + // IsForwardedPacket is true if the packet is being forwarded. + IsForwardedPacket bool } // TransportErrorKind enumerates error types that are handled by the transport @@ -655,9 +658,9 @@ type IPNetworkEndpointStats interface { IPStats() *tcpip.IPStats } -// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. -type ForwardingNetworkProtocol interface { - NetworkProtocol +// ForwardingNetworkEndpoint is a network endpoint that may forward packets. +type ForwardingNetworkEndpoint interface { + NetworkEndpoint // Forwarding returns the forwarding configuration. Forwarding() bool @@ -756,11 +759,6 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityHardwareGSO - - // CapabilitySoftwareGSO indicates the link endpoint supports of sending - // multiple packets using a single call (LinkEndpoint.WritePackets). - CapabilitySoftwareGSO ) // NetworkLinkEndpoint is a data-link layer that supports sending network @@ -1047,10 +1045,29 @@ type GSO struct { MaxSize uint32 } +// SupportedGSO returns the type of segmentation offloading supported. +type SupportedGSO int + +const ( + // GSONotSupported indicates that segmentation offloading is not supported. + GSONotSupported SupportedGSO = iota + + // HWGSOSupported indicates that segmentation offloading may be performed by + // the hardware. + HWGSOSupported + + // SWGSOSupported indicates that segmentation offloading may be performed in + // software. + SWGSOSupported +) + // GSOEndpoint provides access to GSO properties. type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 + + // SupportedGSO returns the supported segmentation offloading. + SupportedGSO() SupportedGSO } // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 4ecde5995..f17c04277 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool { // HasSoftwareGSOCapability returns true if the route supports software GSO. func (r *Route) HasSoftwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == SWGSOSupported + } + return false } // HasHardwareGSOCapability returns true if the route supports hardware GSO. func (r *Route) HasHardwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == HWGSOSupported + } + return false } // HasSaveRestoreCapability returns true if the route supports save/restore. @@ -440,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool { // If the source NIC and outgoing NIC are different, make sure the stack has // forwarding enabled, or the packet will be handled locally. - if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { + if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { return false } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 843118b13..8814f45a6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,6 +29,7 @@ import ( "time" "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -65,10 +66,10 @@ type ResumableEndpoint interface { } // uniqueIDGenerator is a default unique ID generator. -type uniqueIDGenerator uint64 +type uniqueIDGenerator atomicbitops.AlignedAtomicUint64 func (u *uniqueIDGenerator) UniqueID() uint64 { - return atomic.AddUint64((*uint64)(u), 1) + return ((*atomicbitops.AlignedAtomicUint64)(u)).Add(1) } // Stack is a networking stack, with all supported protocols, NICs, and route @@ -94,8 +95,9 @@ type Stack struct { } } - mu sync.RWMutex - nics map[tcpip.NICID]*nic + mu sync.RWMutex + nics map[tcpip.NICID]*nic + defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{} // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -322,7 +324,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {} func New(opts Options) *Stack { clock := opts.Clock if clock == nil { - clock = &tcpip.StdClock{} + clock = tcpip.NewStdClock() } if opts.UniqueID == nil { @@ -347,22 +349,23 @@ func New(opts Options) *Stack { } s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -491,37 +494,61 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { - protocol, ok := s.networkProtocols[protocolNum] +// SetNICForwarding enables or disables packet forwarding on the specified NIC +// for the passed protocol. +func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrUnknownProtocol{} + return &tcpip.ErrUnknownNICID{} } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + return nic.setForwarding(protocol, enable) +} + +// NICForwarding returns the forwarding configuration for the specified NIC. +func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrNotSupported{} + return false, &tcpip.ErrUnknownNICID{} } - forwardingProtocol.SetForwarding(enable) - return nil + return nic.forwarding(protocol) } -// Forwarding returns true if packet forwarding between NICs is enabled for the -// passed protocol. -func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { - protocol, ok := s.networkProtocols[protocolNum] - if !ok { - return false +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + doneOnce := false + for id, nic := range s.nics { + if err := nic.setForwarding(protocol, enable); err != nil { + // Expect forwarding to be settable on all interfaces if it was set on + // one. + if doneOnce { + panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err)) + } + + return err + } + + doneOnce = true } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) - if !ok { - return false + if enable { + s.defaultForwardingEnabled[protocol] = struct{}{} + } else { + delete(s.defaultForwardingEnabled, protocol) } - return forwardingProtocol.Forwarding() + return nil } // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in @@ -658,6 +685,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) + for proto := range s.defaultForwardingEnabled { + if err := n.setForwarding(proto, true); err != nil { + panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err)) + } + } s.nics[id] = n if !opts.Disabled { return n.enable() @@ -785,6 +817,10 @@ type NICInfo struct { // value sent in haType field of an ARP Request sent by this NIC and the // value expected in the haType field of an ARP response. ARPHardwareType header.ARPHardwareType + + // Forwarding holds the forwarding status for each network endpoint that + // supports forwarding. + Forwarding map[tcpip.NetworkProtocolNumber]bool } // HasNIC returns true if the NICID is defined in the stack. @@ -814,7 +850,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { netStats[proto] = netEP.Stats() } - nics[id] = NICInfo{ + info := NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), @@ -824,7 +860,23 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), } + + for proto := range s.networkProtocols { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + info.Forwarding[proto] = forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } + } + + nics[id] = info } return nics } @@ -1028,6 +1080,20 @@ func (s *Stack) HandleLocal() bool { return s.handleLocal } +func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + return forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + return false + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1080,7 +1146,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal + onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1119,7 +1185,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // requirement to do this from any RFC but simply a choice made to better // follow a strong host model which the netstack follows at the time of // writing. - if canForward && chosenRoute == (tcpip.Route{}) { + if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) { chosenRoute = route } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8ead3b8df..02d54d29b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct { mu struct { sync.RWMutex - enabled bool + enabled bool + forwarding bool } nic stack.NetworkInterface @@ -138,11 +139,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data().TrimFront(fakeNetHeaderLen) + // DeleteFront invalidates slices. Make a copy before trimming. + nb := append([]byte(nil), hdr...) + pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -225,11 +228,6 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -298,15 +296,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -3020,7 +3018,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4218,8 +4216,8 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) @@ -4273,8 +4271,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go new file mode 100644 index 000000000..7ce43a68e --- /dev/null +++ b/pkg/tcpip/stdclock.go @@ -0,0 +1,130 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +// stdClock implements Clock with the time package. +// +// +stateify savable +type stdClock struct { + // baseTime holds the time when the clock was constructed. + // + // This value is used to calculate the monotonic time from the time package. + // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks, + // + // Operating systems provide both a “wall clock,” which is subject to + // changes for clock synchronization, and a “monotonic clock,” which is not. + // The general rule is that the wall clock is for telling time and the + // monotonic clock is for measuring time. Rather than split the API, in this + // package the Time returned by time.Now contains both a wall clock reading + // and a monotonic clock reading; later time-telling operations use the wall + // clock reading, but later time-measuring operations, specifically + // comparisons and subtractions, use the monotonic clock reading. + // + // ... + // + // If Times t and u both contain monotonic clock readings, the operations + // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using + // the monotonic clock readings alone, ignoring the wall clock readings. If + // either t or u contains no monotonic clock reading, these operations fall + // back to using the wall clock readings. + // + // Given the above, we can safely conclude that time.Since(baseTime) will + // return monotonically increasing values if we use time.Now() to set baseTime + // at the time of clock construction. + // + // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per + // https://golang.org/pkg/time/#Since. + baseTime time.Time `state:"nosave"` + + // monotonicOffset is the offset applied to the calculated monotonic time. + // + // monotonicOffset is assigned maxMonotonic after restore so that the + // monotonic time will continue from where it "left off" before saving as part + // of S/R. + monotonicOffset int64 `state:"nosave"` + + // monotonicMU protects maxMonotonic. + monotonicMU sync.Mutex `state:"nosave"` + maxMonotonic int64 +} + +// NewStdClock returns an instance of a clock that uses the time package. +func NewStdClock() Clock { + return &stdClock{ + baseTime: time.Now(), + } +} + +var _ Clock = (*stdClock)(nil) + +// NowNanoseconds implements Clock.NowNanoseconds. +func (*stdClock) NowNanoseconds() int64 { + return time.Now().UnixNano() +} + +// NowMonotonic implements Clock.NowMonotonic. +func (s *stdClock) NowMonotonic() int64 { + sinceBase := time.Since(s.baseTime) + if sinceBase < 0 { + panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) + } + + monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + + // Monotonic time values must never decrease. + if monotonicValue > s.maxMonotonic { + s.maxMonotonic = monotonicValue + } + + return s.maxMonotonic +} + +// AfterFunc implements Clock.AfterFunc. +func (*stdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/stdclock_state.go b/pkg/tcpip/stdclock_state.go new file mode 100644 index 000000000..795db9181 --- /dev/null +++ b/pkg/tcpip/stdclock_state.go @@ -0,0 +1,26 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "time" + +// afterLoad is invoked by stateify. +func (s *stdClock) afterLoad() { + s.baseTime = time.Now() + + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + s.monotonicOffset = s.maxMonotonic +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0ba71b62e..797778e08 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -37,9 +37,9 @@ import ( "reflect" "strconv" "strings" - "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -73,7 +73,7 @@ type Clock interface { // nanoseconds since the Unix epoch. NowNanoseconds() int64 - // NowMonotonic returns a monotonic time value. + // NowMonotonic returns a monotonic time value at nanosecond resolution. NowMonotonic() int64 // AfterFunc waits for the duration to elapse and then calls f in its own @@ -1107,6 +1107,7 @@ const ( // LingerOption is used by SetSockOpt/GetSockOpt to set/get the // duration for which a socket lingers before returning from Close. // +// +marshal // +stateify savable type LingerOption struct { Enabled bool @@ -1219,7 +1220,7 @@ type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. type StatCounter struct { - count uint64 + count atomicbitops.AlignedAtomicUint64 } // Increment adds one to the counter. @@ -1234,12 +1235,12 @@ func (s *StatCounter) Decrement() { // Value returns the current value of the counter. func (s *StatCounter) Value(name ...string) uint64 { - return atomic.LoadUint64(&s.count) + return s.count.Load() } // IncrementBy increments the counter by v. func (s *StatCounter) IncrementBy(v uint64) { - atomic.AddUint64(&s.count, v) + s.count.Add(v) } func (s *StatCounter) String() string { @@ -1527,6 +1528,42 @@ type IGMPStats struct { // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats) } +// IPForwardingStats collects stats related to IP forwarding (both v4 and v6). +type IPForwardingStats struct { + // LINT.IfChange(IPForwardingStats) + + // Unrouteable is the number of IP packets received which were dropped + // because a route to their destination could not be constructed. + Unrouteable *StatCounter + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL *StatCounter + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource *StatCounter + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination *StatCounter + + // PacketTooBig is the number of IP packets which were dropped because they + // were too big for the outgoing MTU. + PacketTooBig *StatCounter + + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem *StatCounter + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors *StatCounter + + // LINT.ThenChange(network/internal/ip/stats.go:multiCounterIPForwardingStats) +} + // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // LINT.IfChange(IPStats) @@ -1534,6 +1571,10 @@ type IPStats struct { // PacketsReceived is the number of IP packets received from the link layer. PacketsReceived *StatCounter + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived *StatCounter + // DisabledPacketsReceived is the number of IP packets received from the link // layer when the IP layer is disabled. DisabledPacketsReceived *StatCounter @@ -1573,6 +1614,10 @@ type IPStats struct { // chain. IPTablesInputDropped *StatCounter + // IPTablesForwardDropped is the number of IP packets dropped in the Forward + // chain. + IPTablesForwardDropped *StatCounter + // IPTablesOutputDropped is the number of IP packets dropped in the Output // chain. IPTablesOutputDropped *StatCounter @@ -1595,6 +1640,9 @@ type IPStats struct { // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived *StatCounter + // Forwarding collects stats related to IP forwarding. + Forwarding IPForwardingStats + // LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPStats) } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index d4f7bb5ff..ab2dab60c 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -31,12 +31,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/udp", ], ) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index dbd279c94..92fa6257d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -16,6 +16,7 @@ package forward_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -34,6 +35,39 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) +} + +func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) +} + +func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) +} + +func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) +} + func TestForwarding(t *testing.T) { const listenPort = 8080 @@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) { const ( nicID1 = 1 nicID2 = 2 - ttl = 64 ) var ( ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") ) - rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) - } - - rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) - } - - v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo))) - } - - v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest))) - } - tests := []struct { name string srcAddr, dstAddr tcpip.Address @@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) }, }, { @@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) }, }, @@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) }, }, { @@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) }, }, } @@ -475,11 +480,11 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) } - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } s.SetRouteTable([]tcpip.Route{ @@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) { }) } } + +func TestPerInterfaceForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ) + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + // ARP is not used in this test but it is a network protocol that does + // not support forwarding. We install the protocol to make sure that + // forwarding information for a NIC is only reported for network + // protocols that support forwarding. + arp.NewProtocol, + + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + for _, add := range [...]struct { + nicID tcpip.NICID + addr tcpip.ProtocolAddress + }{ + { + nicID: nicID1, + addr: utils.RouterNIC1IPv4Addr, + }, + { + nicID: nicID1, + addr: utils.RouterNIC1IPv6Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv4Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv6Addr, + }, + } { + if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + } + } + + // Only enable forwarding on NIC1 and make sure that only packets arriving + // on NIC1 are forwarded. + for _, netProto := range netProtos { + if err := s.SetNICForwarding(nicID1, netProto, true); err != nil { + t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err) + } + } + + nicsInfo := s.NICInfo() + for _, subTest := range [...]struct { + nicID tcpip.NICID + nicEP *channel.Endpoint + otherNICID tcpip.NICID + otherNICEP *channel.Endpoint + expectForwarding bool + }{ + { + nicID: nicID1, + nicEP: e1, + otherNICID: nicID2, + otherNICEP: e2, + expectForwarding: true, + }, + { + nicID: nicID2, + nicEP: e2, + otherNICID: nicID2, + otherNICEP: e1, + expectForwarding: false, + }, + } { + t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) { + nicInfo, ok := nicsInfo[subTest.nicID] + if !ok { + t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo) + } else { + forwarding := make(map[tcpip.NetworkProtocolNumber]bool) + for _, netProto := range netProtos { + forwarding[netProto] = subTest.expectForwarding + } + + if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" { + t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: subTest.otherNICID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: subTest.otherNICID, + }, + }) + + test.rx(subTest.nicEP, test.srcAddr, test.dstAddr) + if p, ok := subTest.nicEP.Read(); ok { + t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p) + } + if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding { + t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding) + } else if subTest.expectForwarding { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index c61d4e788..07ba2b837 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -19,12 +19,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) { }) } } + +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoReply(e, src, dst, ttl) +} + +func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoReply(e, src, dst, ttl) +} + +func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) +} + +func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) +} + +func TestForwardingHook(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + local bool + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 remote", + netProto: ipv4.ProtocolNumber, + local: false, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoReply, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 local", + netProto: ipv4.ProtocolNumber, + local: true, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr.Address, + rx: rxICMPv4EchoReply, + }, + { + name: "IPv6 remote", + netProto: ipv6.ProtocolNumber, + local: false, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoReply, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 local", + netProto: ipv6.ProtocolNumber, + local: true, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + rx: rxICMPv6EchoReply, + }, + } + + setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[stack.Forward] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } + } + + boolToInt := func(v bool) uint64 { + if v { + return 1 + } + return 0 + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectForward bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectForward: true, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + expectForward: false, + }, + { + name: "Drop with input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + expectForward: false, + }, + { + name: "Drop with output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + expectForward: false, + }, + { + name: "Drop with input and output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + expectForward: false, + }, + + { + name: "Drop with other input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other input and output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + expectForward: true, + }, + { + name: "Drop with input and other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other input and other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + expectForward: true, + }, + + { + name: "Drop with inverted input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + expectForward: true, + }, + { + name: "Drop with inverted output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + expectForward: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + expectTransmitPacket := subTest.expectForward && !test.local + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want { + t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want) + } + if got := ip1Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want) + } + if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + p, ok := e2.Read() + if ok != expectTransmitPacket { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket) + } + if expectTransmitPacket { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3df1bbd68..87d36e1dd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) { } if test.forwarding { - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 8fd9be32b..2e6ae55ea 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) } - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { @@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }) } -// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on -// the provided endpoint. -func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { +func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) + pkt.SetType(ty) pkt.SetCode(header.ICMPv4UnusedCode) pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, 0)) @@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) })) } -// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo) +} + +// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on +// the provided endpoint. +func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply) +} + +func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetType(ty) pkt.SetCode(header.ICMPv6UnusedCode) pkt.SetChecksum(0) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) Data: hdr.View().ToVectorisedView(), })) } + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest) +} + +// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on +// the provided endpoint. +func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply) +} diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD index 472545a5d..02ee86ff1 100644 --- a/pkg/tcpip/testutil/BUILD +++ b/pkg/tcpip/testutil/BUILD @@ -5,7 +5,10 @@ package(licenses = ["notice"]) go_library( name = "testutil", testonly = True, - srcs = ["testutil.go"], + srcs = [ + "testutil.go", + "testutil_unsafe.go", + ], visibility = ["//visibility:public"], deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index 1aaed590f..f84d399fb 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -18,6 +18,8 @@ package testutil import ( "fmt" "net" + "reflect" + "strings" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -41,3 +43,69 @@ func MustParse6(addr string) tcpip.Address { } return tcpip.Address(ip) } + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) + } + + // The field names are expected to match (case insensitive). + if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go index 5ff764800..5ff764800 100644 --- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go +++ b/pkg/tcpip/testutil/testutil_unsafe.go diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go deleted file mode 100644 index eeea97b12..000000000 --- a/pkg/tcpip/time_unsafe.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.9 -// +build !go1.18 - -// Check go:linkname function signatures when updating Go version. - -package tcpip - -import ( - "time" // Used with go:linkname. - _ "unsafe" // Required for go:linkname. -) - -// StdClock implements Clock with the time package. -// -// +stateify savable -type StdClock struct{} - -var _ Clock = (*StdClock)(nil) - -//go:linkname now time.now -func now() (sec int64, nsec int32, mono int64) - -// NowNanoseconds implements Clock.NowNanoseconds. -func (*StdClock) NowNanoseconds() int64 { - sec, nsec, _ := now() - return sec*1e9 + int64(nsec) -} - -// NowMonotonic implements Clock.NowMonotonic. -func (*StdClock) NowMonotonic() int64 { - _, _, mono := now() - return mono -} - -// AfterFunc implements Clock.AfterFunc. -func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { - return &stdTimer{ - t: time.AfterFunc(d, f), - } -} - -type stdTimer struct { - t *time.Timer -} - -var _ Timer = (*stdTimer)(nil) - -// Stop implements Timer.Stop. -func (st *stdTimer) Stop() bool { - return st.t.Stop() -} - -// Reset implements Timer.Reset. -func (st *stdTimer) Reset(d time.Duration) { - st.t.Reset(d) -} - -// NewStdTimer returns a Timer implemented with the time package. -func NewStdTimer(t *time.Timer) Timer { - return &stdTimer{t: t} -} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index a82384c49..1633d0aeb 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -29,7 +29,7 @@ const ( ) func TestJobReschedule(t *testing.T) { - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var wg sync.WaitGroup var lock sync.Mutex @@ -43,7 +43,7 @@ func TestJobReschedule(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { wg.Done() }) job.Schedule(shortDuration) @@ -56,11 +56,11 @@ func TestJobReschedule(t *testing.T) { func TestJobExecution(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) @@ -83,11 +83,11 @@ func TestJobExecution(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(middleDuration) lock.Lock() @@ -114,12 +114,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -151,13 +151,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -174,12 +174,12 @@ func TestJobImmediatelyCancel(t *testing.T) { func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -206,12 +206,12 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. @@ -239,12 +239,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { job.Cancel() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9948f305b..8afde7fca 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -756,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 48417f192..0f20d3856 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -126,7 +126,15 @@ go_test( go_test( name = "tcp_test", size = "small", - srcs = ["timer_test.go"], + srcs = [ + "segment_test.go", + "timer_test.go", + ], library = ":tcp", - deps = ["//pkg/sleep"], + deps = [ + "//pkg/sleep", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 524d5cabf..5e03e7715 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -586,8 +586,14 @@ func (h *handshake) complete() tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + // Check for any ICMP errors notified to us. if n¬ifyError != 0 { - return h.ep.lastErrorLocked() + if err := h.ep.lastErrorLocked(); err != nil { + return err + } + // Flag the handshake failure as aborted if the lastError is + // cleared because of a socket layer call. + return &tcpip.ErrConnectionAborted{} } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -1362,8 +1368,24 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Reaching this point means that we successfully completed the 3-way - // handshake with our peer. - // + // handshake with our peer. The current endpoint state could be any state + // post ESTABLISHED, including CLOSED or ERROR if the endpoint processes a + // RST from the peer via the dispatcher fast path, before the loop is + // started. + if s := e.EndpointState(); !s.connected() { + switch s { + case StateClose, StateError: + // If the endpoint is in CLOSED/ERROR state, sender state has to be + // initialized if the endpoint was previously established. + if e.snd != nil { + break + } + fallthrough + default: + panic("endpoint was not established, current state " + s.String()) + } + } + // Completing the 3-way handshake is an indication that the route is valid // and the remote is reachable as the only way we can complete a handshake // is if our SYN reached the remote and their ACK reached us. diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index d6d68f128..f148d505d 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -37,8 +38,8 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -49,8 +50,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -156,8 +157,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -391,7 +392,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -525,7 +526,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -611,7 +612,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3a7b2d166..50d39cbad 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1280,6 +1280,12 @@ func (e *endpoint) LastError() tcpip.Error { return e.lastErrorLocked() } +// LastErrorLocked reads and clears lastError with e.mu held. +// Only to be used in tests. +func (e *endpoint) LastErrorLocked() tcpip.Error { + return e.lastErrorLocked() +} + // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() @@ -1595,7 +1601,7 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of -// receive buffer size. This is chosen arbitrairly. +// receive buffer size. This is chosen arbitrarily. // crossed will be true if the window size crossed the ACK threshold. // above will be true if the new window is >= ACK threshold and false // otherwise. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index ee2c08cd6..133371455 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -148,6 +148,18 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { } newWnd = curWnd } + + // Apply silly-window avoidance when recovering from zero-window situation. + // Keep advertising zero receive window up until the new window reaches a + // threshold. + if r.rcvWnd == 0 && newWnd != 0 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if crossed, above := r.ep.windowCrossedACKThresholdLocked(int(newWnd), int(r.ep.ops.GetReceiveBufferSize())); !crossed && !above { + newWnd = 0 + } + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() + } + // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c28641be3..7e5ba6ef7 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -140,6 +140,15 @@ func (s *segment) clone() *segment { return t } +// merge merges data in oth and clears oth. +func (s *segment) merge(oth *segment) { + s.data.Append(oth.data) + s.dataMemSize = s.data.Size() + + oth.data = buffer.VectorisedView{} + oth.dataMemSize = oth.data.Size() +} + // flagIsSet checks if at least one flag in flags is set in s.flags. func (s *segment) flagIsSet(flags header.TCPFlags) bool { return s.flags&flags != 0 diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go new file mode 100644 index 000000000..486016fc0 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -0,0 +1,67 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type segmentSizeWants struct { + DataSize int + SegMemSize int +} + +func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) { + t.Helper() + got := segmentSizeWants{ + DataSize: seg.data.Size(), + SegMemSize: seg.segMemSize(), + } + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("%s differs (-want +got):\n%s", name, diff) + } +} + +func TestSegmentMerge(t *testing.T) { + id := stack.TransportEndpointID{} + seg1 := newOutgoingSegment(id, buffer.NewView(10)) + defer seg1.decRef() + seg2 := newOutgoingSegment(id, buffer.NewView(20)) + defer seg2.decRef() + + checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ + DataSize: 10, + SegMemSize: SegSize + 10, + }) + checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ + DataSize: 20, + SegMemSize: SegSize + 20, + }) + + seg1.merge(seg2) + + checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ + DataSize: 30, + SegMemSize: SegSize + 30, + }) + checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ + DataSize: 0, + SegMemSize: SegSize, + }) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 2b32cb7b2..f43e86677 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -716,15 +716,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // triggering bugs in poorly written DNS // implementations. var nextTooBig bool - for seg.Next() != nil && seg.Next().data.Size() != 0 { - if seg.data.Size()+seg.Next().data.Size() > available { + for nSeg := seg.Next(); nSeg != nil && nSeg.data.Size() != 0; nSeg = seg.Next() { + if seg.data.Size()+nSeg.data.Size() > available { nextTooBig = true break } - seg.data.Append(seg.Next().data) - - // Consume the segment that we just merged in. - s.writeList.Remove(seg.Next()) + seg.merge(nSeg) + s.writeList.Remove(nSeg) + nSeg.decRef() } if !nextTooBig && seg.data.Size() < available { // Segment is not full. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 3750b0691..9916182e3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for receive to be notified. select { case <-notifyRead: @@ -130,8 +130,8 @@ func TestGiveUpConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -145,8 +145,8 @@ func TestGiveUpConnect(t *testing.T) { // and stats updates. { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAborted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -159,6 +159,76 @@ func TestGiveUpConnect(t *testing.T) { } } +// Test for ICMP error handling without completing handshake. +func TestConnectICMPError(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventHUp) + defer wq.EventUnregister(&waitEntry) + + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + } + } + + syn := c.GetPacket() + checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) + + wep := ep.(interface { + StopWork() + ResumeWork() + LastErrorLocked() tcpip.Error + }) + + // Stop the protocol loop, ensure that the ICMP error is processed and + // the last ICMP error is read before the loop is resumed. This sanity + // tests the handshake completion logic on ICMP errors. + wep.StopWork() + + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) + + for { + if err := wep.LastErrorLocked(); err != nil { + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) + } + break + } + time.Sleep(time.Millisecond) + } + + wep.ResumeWork() + + <-notifyCh + + // The stack would have unregistered the endpoint because of the ICMP error. + // Expect a RST for any subsequent packets sent to the endpoint. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, + AckNum: c.IRS + 1, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + func TestConnectIncrementActiveConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -202,8 +272,8 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -393,7 +463,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -936,8 +1006,8 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} { err := c.EP.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) } } @@ -1543,8 +1613,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -1604,8 +1674,8 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -2473,7 +2543,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -2545,7 +2615,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3077,8 +3147,8 @@ func TestSetTTL(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3137,7 +3207,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3191,7 +3261,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3266,8 +3336,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3385,8 +3455,8 @@ loop: case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } break loop case <-time.After(1 * time.Second): @@ -3436,8 +3506,8 @@ func TestSendOnResetConnection(t *testing.T) { var r bytes.Reader r.Reset(make([]byte, 10)) _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) } } @@ -4390,8 +4460,8 @@ func TestReadAfterClosedState(t *testing.T) { var buf bytes.Buffer { _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { + t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) } } } @@ -4435,8 +4505,8 @@ func TestReusePort(t *testing.T) { } { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } c.EP.Close() @@ -4724,8 +4794,8 @@ func TestSelfConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -5428,7 +5498,7 @@ func TestListenBacklogFull(t *testing.T) { } lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ { + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } @@ -5452,7 +5522,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5469,7 +5539,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5481,7 +5551,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5794,7 +5864,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Try to accept the connections in the backlog. newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5865,7 +5935,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5881,7 +5951,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -6020,7 +6090,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { t.Fatalf("Accept failed: %s", err) } - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) @@ -6088,7 +6158,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6158,7 +6228,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6210,7 +6280,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6228,8 +6298,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } { err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { + t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) } } // Listening endpoint remains in listen state. @@ -6349,7 +6419,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } } @@ -6480,7 +6550,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } totalCopied += res.Count @@ -6672,7 +6742,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6791,7 +6861,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6898,7 +6968,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6988,7 +7058,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7062,7 +7132,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7212,7 +7282,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7643,8 +7713,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Send data. This should result in an acceptable endpoint. @@ -7702,8 +7772,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Sleep for a little of the tcpDeferAccept timeout. diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 16f8c5212..53efecc5a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -1214,9 +1214,9 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { if enable { - c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.HWGSOSupported } else { - c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.GSONotSupported } } diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index 4855a52fc..12fe98b16 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -82,10 +82,15 @@ func (p *profile) createProcess(c *Container) error { } // The root directory of this container's runtime. - root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + rootDir := fmt.Sprintf("/var/run/docker/runtime-%s/moby", c.runtime) + if _, err := os.Stat(rootDir); os.IsNotExist(err) { + // In docker v20+, due to https://github.com/moby/moby/issues/42345 the + // rootDir seems to always be the following. + rootDir = "/var/run/docker/runtime-runc/moby" + } - // Format is `runsc --root=rootdir debug --profile-*=file --duration=24h containerID`. - args := []string{root, "debug"} + // Format is `runsc --root=rootDir debug --profile-*=file --duration=24h containerID`. + args := []string{fmt.Sprintf("--root=%s", rootDir), "debug"} for _, profileArg := range p.Types { outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg)) args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath)) |