summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/netfilter.go28
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go2
-rw-r--r--pkg/abi/linux/netlink.go6
-rw-r--r--pkg/abi/linux/netlink_route.go6
-rw-r--r--pkg/abi/linux/socket.go16
-rw-r--r--pkg/bits/bits.go10
-rw-r--r--pkg/marshal/BUILD1
-rw-r--r--pkg/marshal/primitive/primitive.go75
-rw-r--r--pkg/marshal/util.go23
-rw-r--r--pkg/merkletree/merkletree.go4
-rw-r--r--pkg/p9/client_file.go16
-rw-r--r--pkg/p9/file.go60
-rw-r--r--pkg/p9/handlers.go28
-rw-r--r--pkg/p9/messages.go84
-rw-r--r--pkg/p9/p9.go28
-rw-r--r--pkg/p9/version.go8
-rw-r--r--pkg/ring0/pagetables/pagetables.go9
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD1
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go329
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go6
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/revalidate.go386
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go12
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go6
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go2
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go42
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go39
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go4
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD4
-rw-r--r--pkg/sentry/socket/control/control.go66
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go23
-rw-r--r--pkg/sentry/socket/hostinet/stack.go29
-rw-r--r--pkg/sentry/socket/netfilter/BUILD4
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go13
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go7
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go7
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go11
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go9
-rw-r--r--pkg/sentry/socket/netfilter/targets.go38
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go8
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go8
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/message.go40
-rw-r--r--pkg/sentry/socket/netlink/message_test.go18
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go30
-rw-r--r--pkg/sentry/socket/netlink/socket.go21
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go40
-rw-r--r--pkg/sentry/socket/socket.go15
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/linux64_amd64.go1
-rw-r--r--pkg/sentry/strace/linux64_arm64.go1
-rw-r--r--pkg/sentry/strace/socket.go32
-rw-r--r--pkg/sentry/syscalls/epoll.go8
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go56
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go52
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go2
-rw-r--r--pkg/sentry/vfs/file_description.go14
-rw-r--r--pkg/sentry/vfs/opath.go4
-rw-r--r--pkg/sentry/vfs/resolving_path.go84
-rw-r--r--pkg/sentry/vfs/vfs.go83
-rw-r--r--pkg/tcpip/BUILD3
-rw-r--r--pkg/tcpip/link/channel/channel.go11
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go55
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/link/nested/nested.go8
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go13
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go16
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go20
-rw-r--r--pkg/tcpip/stack/forwarding_test.go9
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go2
-rw-r--r--pkg/tcpip/stack/packet_buffer.go7
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go140
-rw-r--r--pkg/tcpip/stack/registration.go24
-rw-r--r--pkg/tcpip/stack/route.go10
-rw-r--r--pkg/tcpip/stack/stack.go2
-rw-r--r--pkg/tcpip/stack/stack_test.go6
-rw-r--r--pkg/tcpip/stdclock.go130
-rw-r--r--pkg/tcpip/stdclock_state.go26
-rw-r--r--pkg/tcpip/tcpip.go3
-rw-r--r--pkg/tcpip/time_unsafe.go75
-rw-r--r--pkg/tcpip/timer_test.go32
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go4
92 files changed, 1799 insertions, 794 deletions
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 35c632168..3fd05483a 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -245,6 +245,8 @@ const SizeOfXTCounters = 16
// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
// exposing different data to the user and kernel, but this struct holds only
// the user data.
+//
+// +marshal
type XTEntryMatch struct {
MatchSize uint16
Name ExtensionName
@@ -284,6 +286,8 @@ const SizeOfXTGetRevision = 30
// include/uapi/linux/netfilter/x_tables.h. That struct contains a union
// exposing different data to the user and kernel, but this struct holds only
// the user data.
+//
+// +marshal
type XTEntryTarget struct {
TargetSize uint16
Name ExtensionName
@@ -306,6 +310,8 @@ type KernelXTEntryTarget struct {
// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTStandardTarget struct {
Target XTEntryTarget
// A positive verdict indicates a jump, and is the offset from the
@@ -322,6 +328,8 @@ const SizeOfXTStandardTarget = 40
// beginning of user-defined chains by putting the name of the chain in
// ErrorName. It corresponds to struct xt_error_target in
// include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTErrorTarget struct {
Target XTEntryTarget
Name ErrorName
@@ -349,6 +357,8 @@ const (
// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range
// in include/uapi/linux/netfilter/nf_nat.h. The fields are in
// network byte order.
+//
+// +marshal
type NfNATIPV4Range struct {
Flags uint32
MinIP [4]byte
@@ -359,6 +369,8 @@ type NfNATIPV4Range struct {
// NfNATIPV4MultiRangeCompat corresponds to struct
// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h.
+//
+// +marshal
type NfNATIPV4MultiRangeCompat struct {
RangeSize uint32
RangeIPV4 NfNATIPV4Range
@@ -366,6 +378,8 @@ type NfNATIPV4MultiRangeCompat struct {
// XTRedirectTarget triggers a redirect when reached.
// Adding 4 bytes of padding to make the struct 8 byte aligned.
+//
+// +marshal
type XTRedirectTarget struct {
Target XTEntryTarget
NfRange NfNATIPV4MultiRangeCompat
@@ -377,6 +391,8 @@ const SizeOfXTRedirectTarget = 56
// XTSNATTarget triggers Source NAT when reached.
// Adding 4 bytes of padding to make the struct 8 byte aligned.
+//
+// +marshal
type XTSNATTarget struct {
Target XTEntryTarget
NfRange NfNATIPV4MultiRangeCompat
@@ -463,6 +479,8 @@ var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTReplace struct {
Name TableName
ValidHooks uint32
@@ -502,6 +520,8 @@ func (tn TableName) String() string {
// ErrorName holds the name of a netfilter error. These can also hold
// user-defined chains.
+//
+// +marshal
type ErrorName [XT_FUNCTION_MAXNAMELEN]byte
// String implements fmt.Stringer.
@@ -520,6 +540,8 @@ func goString(cstring []byte) string {
// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp
// in include/uapi/linux/netfilter/xt_tcpudp.h.
+//
+// +marshal
type XTTCP struct {
// SourcePortStart specifies the inclusive start of the range of source
// ports to which the matcher applies.
@@ -573,6 +595,8 @@ const (
// XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp
// in include/uapi/linux/netfilter/xt_tcpudp.h.
+//
+// +marshal
type XTUDP struct {
// SourcePortStart is the inclusive start of the range of source ports
// to which the matcher applies.
@@ -613,6 +637,8 @@ const (
// IPTOwnerInfo holds data for matching packets with owner. It corresponds
// to struct ipt_owner_info in libxt_owner.c of iptables binary.
+//
+// +marshal
type IPTOwnerInfo struct {
// UID is user id which created the packet.
UID uint32
@@ -634,7 +660,7 @@ type IPTOwnerInfo struct {
Match uint8
// Invert flips the meaning of Match field.
- Invert uint8
+ Invert uint8 `marshal:"unaligned"`
}
// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo.
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index f7c70b430..b088b207c 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -264,6 +264,8 @@ const (
// NFNATRange corresponds to struct nf_nat_range in
// include/uapi/linux/netfilter/nf_nat.h.
+//
+// +marshal
type NFNATRange struct {
Flags uint32
MinAddr Inet6Addr
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
index b41f94a69..232fee67e 100644
--- a/pkg/abi/linux/netlink.go
+++ b/pkg/abi/linux/netlink.go
@@ -53,6 +53,8 @@ type SockAddrNetlink struct {
const SockAddrNetlinkSize = 12
// NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h.
+//
+// +marshal
type NetlinkMessageHeader struct {
Length uint32
Type uint16
@@ -99,6 +101,8 @@ const NLMSG_ALIGNTO = 4
// NetlinkAttrHeader is the header of a netlink attribute, followed by payload.
//
// This is struct nlattr, from uapi/linux/netlink.h.
+//
+// +marshal
type NetlinkAttrHeader struct {
Length uint16
Type uint16
@@ -126,6 +130,8 @@ const (
)
// NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h.
+//
+// +marshal
type NetlinkErrorMessage struct {
Error int32
Header NetlinkMessageHeader
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index ceda0a8d3..581a11b24 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -85,6 +85,8 @@ const (
)
// InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h.
+//
+// +marshal
type InterfaceInfoMessage struct {
Family uint8
_ uint8
@@ -164,6 +166,8 @@ const (
)
// InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h.
+//
+// +marshal
type InterfaceAddrMessage struct {
Family uint8
PrefixLen uint8
@@ -193,6 +197,8 @@ const (
)
// RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h.
+//
+// +marshal
type RouteMessage struct {
Family uint8
DstLen uint8
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 185eee0bb..95871b8a5 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -15,7 +15,6 @@
package linux
import (
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal"
)
@@ -251,18 +250,24 @@ type SockAddrInet struct {
}
// Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h.
+//
+// +marshal
type Inet6MulticastRequest struct {
MulticastAddr Inet6Addr
InterfaceIndex int32
}
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
+//
+// +marshal
type InetMulticastRequest struct {
MulticastAddr InetAddr
InterfaceAddr InetAddr
}
// InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h.
+//
+// +marshal
type InetMulticastRequestWithNIC struct {
InetMulticastRequest
InterfaceIndex int32
@@ -491,7 +496,7 @@ type TCPInfo struct {
}
// SizeOfTCPInfo is the binary size of a TCPInfo struct.
-var SizeOfTCPInfo = int(binary.Size(TCPInfo{}))
+var SizeOfTCPInfo = (*TCPInfo)(nil).SizeBytes()
// Control message types, from linux/socket.h.
const (
@@ -502,6 +507,8 @@ const (
// A ControlMessageHeader is the header for a socket control message.
//
// ControlMessageHeader represents struct cmsghdr from linux/socket.h.
+//
+// +marshal
type ControlMessageHeader struct {
Length uint64
Level int32
@@ -510,7 +517,7 @@ type ControlMessageHeader struct {
// SizeOfControlMessageHeader is the binary size of a ControlMessageHeader
// struct.
-var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
+var SizeOfControlMessageHeader = (*ControlMessageHeader)(nil).SizeBytes()
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
@@ -527,6 +534,7 @@ type ControlMessageCredentials struct {
//
// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
//
+// +marshal
// +stateify savable
type ControlMessageIPPacketInfo struct {
NIC int32
@@ -536,7 +544,7 @@ type ControlMessageIPPacketInfo struct {
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
-var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
+var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes()
// A ControlMessageRights is an SCM_RIGHTS socket control message.
type ControlMessageRights []int32
diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go
index a26433ad6..d16448c3d 100644
--- a/pkg/bits/bits.go
+++ b/pkg/bits/bits.go
@@ -14,3 +14,13 @@
// Package bits includes all bit related types and operations.
package bits
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD
index 7cd89e639..7a5002176 100644
--- a/pkg/marshal/BUILD
+++ b/pkg/marshal/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"marshal.go",
"marshal_impl_util.go",
+ "util.go",
],
visibility = [
"//:sandbox",
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
index 32c8ed138..6f38992b7 100644
--- a/pkg/marshal/primitive/primitive.go
+++ b/pkg/marshal/primitive/primitive.go
@@ -125,6 +125,81 @@ func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) {
var _ marshal.Marshallable = (*ByteSlice)(nil)
+// The following set of functions are convenient shorthands for wrapping a
+// built-in type in a marshallable primitive type. For example:
+//
+// func useMarshallable(m marshal.Marshallable) { ... }
+//
+// // Compare:
+//
+// buf = []byte{...}
+// // useMarshallable(&primitive.ByteSlice(buf)) // Not allowed, can't address temp value.
+// bufP := primitive.ByteSlice(buf)
+// useMarshallable(&bufP)
+//
+// // Vs:
+//
+// useMarshallable(AsByteSlice(buf))
+//
+// Note that the argument to these function escapes, so avoid using them on very
+// hot code paths. But generally if a function accepts an interface as an
+// argument, the argument escapes anyways.
+
+// AllocateInt8 returns x as a marshallable.
+func AllocateInt8(x int8) marshal.Marshallable {
+ p := Int8(x)
+ return &p
+}
+
+// AllocateUint8 returns x as a marshallable.
+func AllocateUint8(x uint8) marshal.Marshallable {
+ p := Uint8(x)
+ return &p
+}
+
+// AllocateInt16 returns x as a marshallable.
+func AllocateInt16(x int16) marshal.Marshallable {
+ p := Int16(x)
+ return &p
+}
+
+// AllocateUint16 returns x as a marshallable.
+func AllocateUint16(x uint16) marshal.Marshallable {
+ p := Uint16(x)
+ return &p
+}
+
+// AllocateInt32 returns x as a marshallable.
+func AllocateInt32(x int32) marshal.Marshallable {
+ p := Int32(x)
+ return &p
+}
+
+// AllocateUint32 returns x as a marshallable.
+func AllocateUint32(x uint32) marshal.Marshallable {
+ p := Uint32(x)
+ return &p
+}
+
+// AllocateInt64 returns x as a marshallable.
+func AllocateInt64(x int64) marshal.Marshallable {
+ p := Int64(x)
+ return &p
+}
+
+// AllocateUint64 returns x as a marshallable.
+func AllocateUint64(x uint64) marshal.Marshallable {
+ p := Uint64(x)
+ return &p
+}
+
+// AsByteSlice returns b as a marshallable. Note that this allocates a new slice
+// header, but does not copy the slice contents.
+func AsByteSlice(b []byte) marshal.Marshallable {
+ bs := ByteSlice(b)
+ return &bs
+}
+
// Below, we define some convenience functions for marshalling primitive types
// using the newtypes above, without requiring superfluous casts.
diff --git a/pkg/marshal/util.go b/pkg/marshal/util.go
new file mode 100644
index 000000000..c1e5475bd
--- /dev/null
+++ b/pkg/marshal/util.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package marshal
+
+// Marshal returns the serialized contents of m in a newly allocated
+// byte slice.
+func Marshal(m Marshallable) []byte {
+ buf := make([]byte, m.SizeBytes())
+ m.MarshalUnsafe(buf)
+ return buf
+}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 6450f664c..ac7868ad9 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -36,7 +36,6 @@ const (
)
// DigestSize returns the size (in bytes) of a digest.
-// TODO(b/156980949): Allow config SHA384.
func DigestSize(hashAlgorithm int) int {
switch hashAlgorithm {
case linux.FS_VERITY_HASH_ALG_SHA256:
@@ -69,7 +68,6 @@ func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool)
blockSize: hostarch.PageSize,
}
- // TODO(b/156980949): Allow config SHA384.
switch hashAlgorithms {
case linux.FS_VERITY_HASH_ALG_SHA256:
layout.digestSize = sha256DigestSize
@@ -429,8 +427,6 @@ func Verify(params *VerifyParams) (int64, error) {
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
- // TODO(b/162908070): Investigate possible issues with zero
- // padding the data.
if bytesRead < len(buf) {
for j := bytesRead; j < len(buf); j++ {
buf[j] = 0
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 7abc82e1b..28396b0ea 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -121,6 +121,22 @@ func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, At
return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil
}
+func (c *clientFile) MultiGetAttr(names []string) ([]FullStat, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, unix.EBADF
+ }
+
+ if !versionSupportsTmultiGetAttr(c.client.version) {
+ return DefaultMultiGetAttr(c, names)
+ }
+
+ rmultigetattr := Rmultigetattr{}
+ if err := c.client.sendRecv(&Tmultigetattr{FID: c.fid, Names: names}, &rmultigetattr); err != nil {
+ return nil, err
+ }
+ return rmultigetattr.Stats, nil
+}
+
// StatFS implements File.StatFS.
func (c *clientFile) StatFS() (FSStat, error) {
if atomic.LoadUint32(&c.closed) != 0 {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index c59c6a65b..97e0231d6 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -15,6 +15,8 @@
package p9
import (
+ "errors"
+
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/fd"
)
@@ -72,6 +74,15 @@ type File interface {
// On the server, WalkGetAttr has a read concurrency guarantee.
WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error)
+ // MultiGetAttr batches up multiple calls to GetAttr(). names is a list of
+ // path components similar to Walk(). If the first component name is empty,
+ // the current file is stat'd and included in the results. If the walk reaches
+ // a file that doesn't exist or not a directory, MultiGetAttr returns the
+ // partial result with no error.
+ //
+ // On the server, MultiGetAttr has a read concurrency guarantee.
+ MultiGetAttr(names []string) ([]FullStat, error)
+
// StatFS returns information about the file system associated with
// this file.
//
@@ -306,6 +317,53 @@ func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error {
type DisallowServerCalls struct{}
// Renamed implements File.Renamed.
-func (*clientFile) Renamed(File, string) {
+func (*DisallowServerCalls) Renamed(File, string) {
panic("Renamed should not be called on the client")
}
+
+// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File.
+func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
+ stats := make([]FullStat, 0, len(names))
+ parent := start
+ mask := AttrMaskAll()
+ for i, name := range names {
+ if len(name) == 0 && i == 0 {
+ qid, valid, attr, err := parent.GetAttr(mask)
+ if err != nil {
+ return nil, err
+ }
+ stats = append(stats, FullStat{
+ QID: qid,
+ Valid: valid,
+ Attr: attr,
+ })
+ continue
+ }
+ qids, child, valid, attr, err := parent.WalkGetAttr([]string{name})
+ if parent != start {
+ _ = parent.Close()
+ }
+ if err != nil {
+ if errors.Is(err, unix.ENOENT) {
+ return stats, nil
+ }
+ return nil, err
+ }
+ stats = append(stats, FullStat{
+ QID: qids[0],
+ Valid: valid,
+ Attr: attr,
+ })
+ if attr.Mode.FileType() != ModeDirectory {
+ // Doesn't need to continue if entry is not a dir. Including symlinks
+ // that cannot be followed.
+ _ = child.Close()
+ break
+ }
+ parent = child
+ }
+ if parent != start {
+ _ = parent.Close()
+ }
+ return stats, nil
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 58312d0cc..758e11b13 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -1421,3 +1421,31 @@ func (t *Tchannel) handle(cs *connState) message {
}
return rchannel
}
+
+// handle implements handler.handle.
+func (t *Tmultigetattr) handle(cs *connState) message {
+ for i, name := range t.Names {
+ if len(name) == 0 && i == 0 {
+ // Empty name is allowed on the first entry to indicate that the current
+ // FID needs to be included in the result.
+ continue
+ }
+ if err := checkSafeName(name); err != nil {
+ return newErr(err)
+ }
+ }
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(unix.EBADF)
+ }
+ defer ref.DecRef()
+
+ var stats []FullStat
+ if err := ref.safelyRead(func() (err error) {
+ stats, err = ref.file.MultiGetAttr(t.Names)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rmultigetattr{Stats: stats}
+}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index cf13cbb69..2ff4694c0 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -254,8 +254,8 @@ func (r *Rwalk) decode(b *buffer) {
// encode implements encoder.encode.
func (r *Rwalk) encode(b *buffer) {
b.Write16(uint16(len(r.QIDs)))
- for _, q := range r.QIDs {
- q.encode(b)
+ for i := range r.QIDs {
+ r.QIDs[i].encode(b)
}
}
@@ -2243,8 +2243,8 @@ func (r *Rwalkgetattr) encode(b *buffer) {
r.Valid.encode(b)
r.Attr.encode(b)
b.Write16(uint16(len(r.QIDs)))
- for _, q := range r.QIDs {
- q.encode(b)
+ for i := range r.QIDs {
+ r.QIDs[i].encode(b)
}
}
@@ -2552,6 +2552,80 @@ func (r *Rchannel) String() string {
return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
+// Tmultigetattr is a multi-getattr request.
+type Tmultigetattr struct {
+ // FID is the FID to be walked.
+ FID FID
+
+ // Names are the set of names to be walked.
+ Names []string
+}
+
+// decode implements encoder.decode.
+func (t *Tmultigetattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ n := b.Read16()
+ t.Names = t.Names[:0]
+ for i := 0; i < int(n); i++ {
+ t.Names = append(t.Names, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (t *Tmultigetattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write16(uint16(len(t.Names)))
+ for _, name := range t.Names {
+ b.WriteString(name)
+ }
+}
+
+// Type implements message.Type.
+func (*Tmultigetattr) Type() MsgType {
+ return MsgTmultigetattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tmultigetattr) String() string {
+ return fmt.Sprintf("Tmultigetattr{FID: %d, Names: %v}", t.FID, t.Names)
+}
+
+// Rmultigetattr is a multi-getattr response.
+type Rmultigetattr struct {
+ // Stats are the set of FullStat returned for each of the names in the
+ // request.
+ Stats []FullStat
+}
+
+// decode implements encoder.decode.
+func (r *Rmultigetattr) decode(b *buffer) {
+ n := b.Read16()
+ r.Stats = r.Stats[:0]
+ for i := 0; i < int(n); i++ {
+ var fs FullStat
+ fs.decode(b)
+ r.Stats = append(r.Stats, fs)
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rmultigetattr) encode(b *buffer) {
+ b.Write16(uint16(len(r.Stats)))
+ for i := range r.Stats {
+ r.Stats[i].encode(b)
+ }
+}
+
+// Type implements message.Type.
+func (*Rmultigetattr) Type() MsgType {
+ return MsgRmultigetattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rmultigetattr) String() string {
+ return fmt.Sprintf("Rmultigetattr{Stats: %v}", r.Stats)
+}
+
const maxCacheSize = 3
// msgFactory is used to reduce allocations by caching messages for reuse.
@@ -2717,6 +2791,8 @@ func init() {
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} })
msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} })
+ msgRegistry.register(MsgTmultigetattr, func() message { return &Tmultigetattr{} })
+ msgRegistry.register(MsgRmultigetattr, func() message { return &Rmultigetattr{} })
msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 648cf4b49..3d452a0bd 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -402,6 +402,8 @@ const (
MsgRallocate MsgType = 139
MsgTsetattrclunk MsgType = 140
MsgRsetattrclunk MsgType = 141
+ MsgTmultigetattr MsgType = 142
+ MsgRmultigetattr MsgType = 143
MsgTchannel MsgType = 250
MsgRchannel MsgType = 251
)
@@ -1178,3 +1180,29 @@ func (a *AllocateMode) encode(b *buffer) {
}
b.Write32(mask)
}
+
+// FullStat is used in the result of a MultiGetAttr call.
+type FullStat struct {
+ QID QID
+ Valid AttrMask
+ Attr Attr
+}
+
+// String implements fmt.Stringer.
+func (f *FullStat) String() string {
+ return fmt.Sprintf("FullStat{QID: %v, Valid: %v, Attr: %v}", f.QID, f.Valid, f.Attr)
+}
+
+// decode implements encoder.decode.
+func (f *FullStat) decode(b *buffer) {
+ f.QID.decode(b)
+ f.Valid.decode(b)
+ f.Attr.decode(b)
+}
+
+// encode implements encoder.encode.
+func (f *FullStat) encode(b *buffer) {
+ f.QID.encode(b)
+ f.Valid.encode(b)
+ f.Attr.encode(b)
+}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index 8d7168ef5..950236162 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 12
+ highestSupportedVersion uint32 = 13
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -179,3 +179,9 @@ func versionSupportsListRemoveXattr(v uint32) bool {
func versionSupportsTsetattrclunk(v uint32) bool {
return v >= 12
}
+
+// versionSupportsTmultiGetAttr returns true if version v supports
+// the TmultiGetAttr message.
+func versionSupportsTmultiGetAttr(v uint32) bool {
+ return v >= 13
+}
diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go
index 3f17fba49..9dac53c80 100644
--- a/pkg/ring0/pagetables/pagetables.go
+++ b/pkg/ring0/pagetables/pagetables.go
@@ -322,3 +322,12 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc
func (p *PageTables) MarkReadOnlyShared() {
p.readOnlyShared = true
}
+
+// PrefaultRootTable touches the root table page to be sure that its physical
+// pages are mapped.
+//
+//go:nosplit
+//go:noinline
+func (p *PageTables) PrefaultRootTable() PTE {
+ return p.root[0]
+}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 6d5258a9b..52879f871 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
"host_named_pipe.go",
"p9file.go",
"regular_file.go",
+ "revalidate.go",
"save_restore.go",
"socket.go",
"special_file.go",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 4b5621043..97ce80853 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -117,6 +117,17 @@ func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
return ds
}
+// Precondition: !parent.isSynthetic() && !child.isSynthetic().
+func appendNewChildDentry(ds **[]*dentry, parent *dentry, child *dentry) {
+ // The new child was added to parent and took a ref on the parent (hence
+ // parent can be removed from cache). A new child has 0 refs for now. So
+ // checkCachingLocked() should be called on both. Call it first on the parent
+ // as it may create space in the cache for child to be inserted - hence
+ // avoiding a cache eviction.
+ *ds = appendDentry(*ds, parent)
+ *ds = appendDentry(*ds, child)
+}
+
// Preconditions: ds != nil.
func putDentrySlice(ds *[]*dentry) {
// Allow dentries to be GC'd.
@@ -169,167 +180,96 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
// * !rp.Done().
-// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up
-// to date.
+// * If !d.cachedMetadataAuthoritative(), then d and all children that are
+// part of rp must have been revalidated.
//
// Postconditions: The returned dentry's cached metadata is up to date.
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, bool, error) {
if !d.isDir() {
- return nil, syserror.ENOTDIR
+ return nil, false, syserror.ENOTDIR
}
if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
- return nil, err
+ return nil, false, err
}
+ followedSymlink := false
afterSymlink:
name := rp.Component()
if name == "." {
rp.Advance()
- return d, nil
+ return d, followedSymlink, nil
}
if name == ".." {
if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
- return nil, err
+ return nil, false, err
} else if isRoot || d.parent == nil {
rp.Advance()
- return d, nil
- }
- // We must assume that d.parent is correct, because if d has been moved
- // elsewhere in the remote filesystem so that its parent has changed,
- // we have no way of determining its new parent's location in the
- // filesystem.
- //
- // Call rp.CheckMount() before updating d.parent's metadata, since if
- // we traverse to another mount then d.parent's metadata is irrelevant.
- if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
- return nil, err
+ return d, followedSymlink, nil
}
- if d != d.parent && !d.cachedMetadataAuthoritative() {
- if err := d.parent.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, false, err
}
rp.Advance()
- return d.parent, nil
+ return d.parent, followedSymlink, nil
}
- child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds)
+ child, err := fs.getChildLocked(ctx, d, name, ds)
if err != nil {
- return nil, err
- }
- if child == nil {
- return nil, syserror.ENOENT
+ return nil, false, err
}
if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
- return nil, err
+ return nil, false, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx, rp.Mount())
if err != nil {
- return nil, err
+ return nil, false, err
}
if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ return nil, false, err
}
+ followedSymlink = true
goto afterSymlink // don't check the current directory again
}
rp.Advance()
- return child, nil
+ return child, followedSymlink, nil
}
// getChildLocked returns a dentry representing the child of parent with the
-// given name. If no such child exists, getChildLocked returns (nil, nil).
+// given name. Returns ENOENT if the child doesn't exist.
//
// Preconditions:
// * fs.renameMu must be locked.
// * parent.dirMu must be locked.
// * parent.isDir().
// * name is not "." or "..".
-//
-// Postconditions: If getChildLocked returns a non-nil dentry, its cached
-// metadata is up to date.
-func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+// * dentry at name has been revalidated
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if len(name) > maxFilenameLen {
return nil, syserror.ENAMETOOLONG
}
- child, ok := parent.children[name]
- if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() {
- // Whether child is nil or not, it is cached information that is
- // assumed to be correct.
+ if child, ok := parent.children[name]; ok || parent.isSynthetic() {
+ if child == nil {
+ return nil, syserror.ENOENT
+ }
return child, nil
}
- // We either don't have cached information or need to verify that it's
- // still correct, either of which requires a remote lookup. Check if this
- // name is valid before performing the lookup.
- return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds)
-}
-// Preconditions: Same as getChildLocked, plus:
-// * !parent.isSynthetic().
-func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
- if child != nil {
- // Need to lock child.metadataMu because we might be updating child
- // metadata. We need to hold the lock *before* getting metadata from the
- // server and release it after updating local metadata.
- child.metadataMu.Lock()
- }
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
- if err != nil && err != syserror.ENOENT {
- if child != nil {
- child.metadataMu.Unlock()
+ if err != nil {
+ if err == syserror.ENOENT {
+ parent.cacheNegativeLookupLocked(name)
}
return nil, err
}
- if child != nil {
- if !file.isNil() && qid.Path == child.qidPath {
- // The file at this path hasn't changed. Just update cached metadata.
- file.close(ctx)
- child.updateFromP9AttrsLocked(attrMask, &attr)
- child.metadataMu.Unlock()
- return child, nil
- }
- child.metadataMu.Unlock()
- if file.isNil() && child.isSynthetic() {
- // We have a synthetic file, and no remote file has arisen to
- // replace it.
- return child, nil
- }
- // The file at this path has changed or no longer exists. Mark the
- // dentry invalidated, and re-evaluate its caching status (i.e. if it
- // has 0 references, drop it). Wait to update parent.children until we
- // know what to replace the existing dentry with (i.e. one of the
- // returns below), to avoid a redundant map access.
- vfsObj.InvalidateDentry(ctx, &child.vfsd)
- if child.isSynthetic() {
- // Normally we don't mark invalidated dentries as deleted since
- // they may still exist (but at a different path), and also for
- // consistency with Linux. However, synthetic files are guaranteed
- // to become unreachable if their dentries are invalidated, so
- // treat their invalidation as deletion.
- child.setDeleted()
- parent.syntheticChildren--
- child.decRefNoCaching()
- parent.dirents = nil
- }
- *ds = appendDentry(*ds, child)
- }
- if file.isNil() {
- // No file exists at this path now. Cache the negative lookup if
- // allowed.
- parent.cacheNegativeLookupLocked(name)
- return nil, nil
- }
+
// Create a new dentry representing the file.
- child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
if err != nil {
file.close(ctx)
delete(parent.children, name)
return nil, err
}
parent.cacheNewChildLocked(child, name)
- // For now, child has 0 references, so our caller should call
- // child.checkCachingLocked(). parent gained a ref so we should also call
- // parent.checkCachingLocked() so it can be removed from the cache if needed.
- *ds = appendDentry(*ds, child)
- *ds = appendDentry(*ds, parent)
+ appendNewChildDentry(ds, parent, child)
return child, nil
}
@@ -344,14 +284,22 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up
// to date.
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil {
+ return nil, err
+ }
for !rp.Final() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
}
d = next
+ if followedSymlink {
+ if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil {
+ return nil, err
+ }
+ }
}
if !d.isDir() {
return nil, syserror.ENOTDIR
@@ -364,20 +312,22 @@ func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// Preconditions: fs.renameMu must be locked.
func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
d := rp.Start().Impl().(*dentry)
- if !d.cachedMetadataAuthoritative() {
- // Get updated metadata for rp.Start() as required by fs.stepLocked().
- if err := d.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
+ if err := fs.revalidatePath(ctx, rp, d, ds); err != nil {
+ return nil, err
}
for !rp.Done() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
}
d = next
+ if followedSymlink {
+ if err := fs.revalidatePath(ctx, rp, d, ds); err != nil {
+ return nil, err
+ }
+ }
}
if rp.MustBeDir() && !d.isDir() {
return nil, syserror.ENOTDIR
@@ -397,13 +347,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by
- // fs.walkParentDirLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return err
- }
- }
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
return err
@@ -421,25 +364,47 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if parent.isDeleted() {
return syserror.ENOENT
}
+ if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil {
+ return err
+ }
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds)
- switch {
- case err != nil && err != syserror.ENOENT:
- return err
- case child != nil:
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+ // Check for existence only if caching information is available. Otherwise,
+ // don't check for existence just yet. We will check for existence if the
+ // checks for writability fail below. Existence check is done by the creation
+ // RPCs themselves.
+ if child, ok := parent.children[name]; ok && child != nil {
return syserror.EEXIST
}
+ checkExistence := func() error {
+ if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT {
+ return err
+ } else if child != nil {
+ return syserror.EEXIST
+ }
+ return nil
+ }
mnt := rp.Mount()
if err := mnt.CheckBeginWrite(); err != nil {
+ // Existence check takes precedence.
+ if existenceErr := checkExistence(); existenceErr != nil {
+ return existenceErr
+ }
return err
}
defer mnt.EndWrite()
if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ // Existence check takes precedence.
+ if existenceErr := checkExistence(); existenceErr != nil {
+ return existenceErr
+ }
return err
}
if !dir && rp.MustBeDir() {
@@ -489,13 +454,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by
- // fs.walkParentDirLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return err
- }
- }
parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
return err
@@ -521,33 +479,32 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return syserror.EISDIR
}
}
+
vfsObj := rp.VirtualFilesystem()
+ if err := fs.revalidateOne(ctx, vfsObj, parent, rp.Component(), &ds); err != nil {
+ return err
+ }
+
mntns := vfs.MountNamespaceFromContext(ctx)
defer mntns.DecRef(ctx)
+
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- child, ok := parent.children[name]
- if ok && child == nil {
- return syserror.ENOENT
- }
-
- sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
- if sticky {
- if !ok {
- // If the sticky bit is set, we need to retrieve the child to determine
- // whether removing it is allowed.
- child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
- if err != nil {
- return err
- }
- } else if child != nil && !child.cachedMetadataAuthoritative() {
- // Make sure the dentry representing the file at name is up to date
- // before examining its metadata.
- child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
- if err != nil {
- return err
- }
+ // Load child if sticky bit is set because we need to determine whether
+ // deletion is allowed.
+ var child *dentry
+ if atomic.LoadUint32(&parent.mode)&linux.ModeSticky == 0 {
+ var ok bool
+ child, ok = parent.children[name]
+ if ok && child == nil {
+ // Hit a negative cached entry, child doesn't exist.
+ return syserror.ENOENT
+ }
+ } else {
+ child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err != nil {
+ return err
}
if err := parent.mayDelete(rp.Credentials(), child); err != nil {
return err
@@ -556,11 +513,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// If a child dentry exists, prepare to delete it. This should fail if it is
// a mount point. We detect mount points by speculatively calling
- // PrepareDeleteDentry, which fails if child is a mount point. However, we
- // may need to revalidate the file in this case to make sure that it has not
- // been deleted or replaced on the remote fs, in which case the mount point
- // will have disappeared. If calling PrepareDeleteDentry fails again on the
- // up-to-date dentry, we can be sure that it is a mount point.
+ // PrepareDeleteDentry, which fails if child is a mount point.
//
// Also note that if child is nil, then it can't be a mount point.
if child != nil {
@@ -575,23 +528,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
child.dirMu.Lock()
defer child.dirMu.Unlock()
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
- // We can skip revalidation in several cases:
- // - We are not in InteropModeShared
- // - The parent directory is synthetic, in which case the child must also
- // be synthetic
- // - We already updated the child during the sticky bit check above
- if parent.cachedMetadataAuthoritative() || sticky {
- return err
- }
- child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
- if err != nil {
- return err
- }
- if child != nil {
- if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
- return err
- }
- }
+ return err
}
}
flags := uint32(0)
@@ -723,13 +660,6 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by
- // fs.walkParentDirLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
- }
d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
return nil, err
@@ -830,7 +760,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// to creating a synthetic one, i.e. one that is kept entirely in memory.
// Check that we're not overriding an existing file with a synthetic one.
- _, err = fs.stepLocked(ctx, rp, parent, true, ds)
+ _, _, err = fs.stepLocked(ctx, rp, parent, true, ds)
switch {
case err == nil:
// Step succeeded, another file exists.
@@ -891,12 +821,6 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
defer unlock()
start := rp.Start().Impl().(*dentry)
- if !start.cachedMetadataAuthoritative() {
- // Get updated metadata for start as required by fs.stepLocked().
- if err := start.updateFromGetattr(ctx); err != nil {
- return nil, err
- }
- }
if rp.Done() {
// Reject attempts to open mount root directory with O_CREAT.
if mayCreate && rp.MustBeDir() {
@@ -905,6 +829,12 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
+ if !start.cachedMetadataAuthoritative() {
+ // Refresh dentry's attributes before opening.
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
start.IncRef()
defer start.DecRef(ctx)
unlock()
@@ -926,9 +856,12 @@ afterTrailingSymlink:
if mayCreate && rp.MustBeDir() {
return nil, syserror.EISDIR
}
+ if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil {
+ return nil, err
+ }
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
- child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
if err == syserror.ENOENT && mayCreate {
if parent.isSynthetic() {
parent.dirMu.Unlock()
@@ -1188,7 +1121,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
- *ds = appendDentry(*ds, child)
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
@@ -1212,7 +1144,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
// Insert the dentry into the tree.
d.cacheNewChildLocked(child, name)
- *ds = appendDentry(*ds, d)
+ appendNewChildDentry(ds, d, child)
if d.cachedMetadataAuthoritative() {
d.touchCMtime()
d.dirents = nil
@@ -1297,18 +1229,23 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
+
vfsObj := rp.VirtualFilesystem()
+ if err := fs.revalidateOne(ctx, vfsObj, newParent, newName, &ds); err != nil {
+ return err
+ }
+ if err := fs.revalidateOne(ctx, vfsObj, oldParent, oldName, &ds); err != nil {
+ return err
+ }
+
// We need a dentry representing the renamed file since, if it's a
// directory, we need to check for write permission on it.
oldParent.dirMu.Lock()
defer oldParent.dirMu.Unlock()
- renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds)
+ renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
if err != nil {
return err
}
- if renamed == nil {
- return syserror.ENOENT
- }
if err := oldParent.mayDelete(creds, renamed); err != nil {
return err
}
@@ -1337,8 +1274,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.isDeleted() {
return syserror.ENOENT
}
- replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds)
- if err != nil {
+ replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds)
+ if err != nil && err != syserror.ENOENT {
return err
}
var replacedVFSD *vfs.Dentry
@@ -1402,9 +1339,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// parent isn't actually changing.
if oldParent != newParent {
oldParent.decRefNoCaching()
- ds = appendDentry(ds, oldParent)
newParent.IncRef()
ds = appendDentry(ds, newParent)
+ ds = appendDentry(ds, oldParent)
if renamed.isSynthetic() {
oldParent.syntheticChildren--
newParent.syntheticChildren++
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index fb42c5f62..21692d2ac 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -32,9 +32,9 @@
// specialFileFD.mu
// specialFileFD.bufMu
//
-// Locking dentry.dirMu in multiple dentries requires that either ancestor
-// dentries are locked before descendant dentries, or that filesystem.renameMu
-// is locked for writing.
+// Locking dentry.dirMu and dentry.metadataMu in multiple dentries requires that
+// either ancestor dentries are locked before descendant dentries, or that
+// filesystem.renameMu is locked for writing.
package gofer
import (
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 21b4a96fe..b0a429d42 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -238,3 +238,10 @@ func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, err
ctx.UninterruptibleSleepFinish(false)
return fdobj, err
}
+
+func (f p9file) multiGetAttr(ctx context.Context, names []string) ([]p9.FullStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ stats, err := f.file.MultiGetAttr(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return stats, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go
new file mode 100644
index 000000000..8f81f0822
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/revalidate.go
@@ -0,0 +1,386 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+type errPartialRevalidation struct{}
+
+// Error implements error.Error.
+func (errPartialRevalidation) Error() string {
+ return "partial revalidation"
+}
+
+type errRevalidationStepDone struct{}
+
+// Error implements error.Error.
+func (errRevalidationStepDone) Error() string {
+ return "stop revalidation"
+}
+
+// revalidatePath checks cached dentries for external modification. File
+// attributes are refreshed and cache is invalidated in case the dentry has been
+// deleted, or a new file/directory created in its place.
+//
+// Revalidation stops at symlinks and mount points. The caller is responsible
+// for revalidating again after symlinks are resolved and after changing to
+// different mounts.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+func (fs *filesystem) revalidatePath(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error {
+ // Revalidation is done even if start is synthetic in case the path is
+ // something like: ../non_synthetic_file.
+ if fs.opts.interop != InteropModeShared {
+ return nil
+ }
+
+ // Copy resolving path to walk the path for revalidation.
+ rp := rpOrig.Copy()
+ err := fs.revalidate(ctx, rp, start, rp.Done, ds)
+ rp.Release(ctx)
+ return err
+}
+
+// revalidateParentDir does the same as revalidatePath, but stops at the parent.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+func (fs *filesystem) revalidateParentDir(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error {
+ // Revalidation is done even if start is synthetic in case the path is
+ // something like: ../non_synthetic_file and parent is non synthetic.
+ if fs.opts.interop != InteropModeShared {
+ return nil
+ }
+
+ // Copy resolving path to walk the path for revalidation.
+ rp := rpOrig.Copy()
+ err := fs.revalidate(ctx, rp, start, rp.Final, ds)
+ rp.Release(ctx)
+ return err
+}
+
+// revalidateOne does the same as revalidatePath, but checks a single dentry.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+func (fs *filesystem) revalidateOne(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) error {
+ // Skip revalidation for interop mode different than InteropModeShared or
+ // if the parent is synthetic (child must be synthetic too, but it cannot be
+ // replaced without first replacing the parent).
+ if parent.cachedMetadataAuthoritative() {
+ return nil
+ }
+
+ parent.dirMu.Lock()
+ child, ok := parent.children[name]
+ parent.dirMu.Unlock()
+ if !ok {
+ return nil
+ }
+
+ state := makeRevalidateState(parent)
+ defer state.release()
+
+ state.add(name, child)
+ return fs.revalidateHelper(ctx, vfsObj, state, ds)
+}
+
+// revalidate revalidates path components in rp until done returns true, or
+// until a mount point or symlink is reached. It may send multiple MultiGetAttr
+// calls to the gofer to handle ".." in the path.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * InteropModeShared is in effect.
+func (fs *filesystem) revalidate(ctx context.Context, rp *vfs.ResolvingPath, start *dentry, done func() bool, ds **[]*dentry) error {
+ state := makeRevalidateState(start)
+ defer state.release()
+
+ // Skip synthetic dentries because the start dentry cannot be replaced in case
+ // it has been created in the remote file system.
+ if !start.isSynthetic() {
+ state.add("", start)
+ }
+
+done:
+ for cur := start; !done(); {
+ var err error
+ cur, err = fs.revalidateStep(ctx, rp, cur, state)
+ if err != nil {
+ switch err.(type) {
+ case errPartialRevalidation:
+ if err := fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds); err != nil {
+ return err
+ }
+
+ // Reset state to release any remaining locks and restart from where
+ // stepping stopped.
+ state.reset()
+ state.start = cur
+
+ // Skip synthetic dentries because the start dentry cannot be replaced in
+ // case it has been created in the remote file system.
+ if !cur.isSynthetic() {
+ state.add("", cur)
+ }
+
+ case errRevalidationStepDone:
+ break done
+
+ default:
+ return err
+ }
+ }
+ }
+ return fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds)
+}
+
+// revalidateStep walks one element of the path and updates revalidationState
+// with the dentry if needed. It may also stop the stepping or ask for a
+// partial revalidation. Partial revalidation requires the caller to revalidate
+// the current revalidationState, release all locks, and resume stepping.
+// In case a symlink is hit, revalidation stops and the caller is responsible
+// for calling revalidate again after the symlink is resolved. Revalidation may
+// also stop for other reasons, like hitting a child not in the cache.
+//
+// Returns:
+// * (dentry, nil): step worked, continue stepping.`
+// * (dentry, errPartialRevalidation): revalidation should be done with the
+// state gathered so far. Then continue stepping with the remainder of the
+// path, starting at `dentry`.
+// * (nil, errRevalidationStepDone): revalidation doesn't need to step any
+// further. It hit a symlink, a mount point, or an uncached dentry.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * !rp.Done().
+// * InteropModeShared is in effect (assumes no negative dentries).
+func (fs *filesystem) revalidateStep(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, state *revalidateState) (*dentry, error) {
+ switch name := rp.Component(); name {
+ case ".":
+ // Do nothing.
+
+ case "..":
+ // Partial revalidation is required when ".." is hit because metadata locks
+ // can only be acquired from parent to child to avoid deadlocks.
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
+ return nil, errRevalidationStepDone{}
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, errPartialRevalidation{}
+ }
+ // We must assume that d.parent is correct, because if d has been moved
+ // elsewhere in the remote filesystem so that its parent has changed,
+ // we have no way of determining its new parent's location in the
+ // filesystem.
+ //
+ // Call rp.CheckMount() before updating d.parent's metadata, since if
+ // we traverse to another mount then d.parent's metadata is irrelevant.
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, errRevalidationStepDone{}
+ }
+ rp.Advance()
+ return d.parent, errPartialRevalidation{}
+
+ default:
+ d.dirMu.Lock()
+ child, ok := d.children[name]
+ d.dirMu.Unlock()
+ if !ok {
+ // child is not cached, no need to validate any further.
+ return nil, errRevalidationStepDone{}
+ }
+
+ state.add(name, child)
+
+ // Symlink must be resolved before continuing with revalidation.
+ if child.isSymlink() {
+ return nil, errRevalidationStepDone{}
+ }
+
+ d = child
+ }
+
+ rp.Advance()
+ return d, nil
+}
+
+// revalidateHelper calls the gofer to stat all dentries in `state`. It will
+// update or invalidate dentries in the cache based on the result.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * InteropModeShared is in effect.
+func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualFilesystem, state *revalidateState, ds **[]*dentry) error {
+ if len(state.names) == 0 {
+ return nil
+ }
+ // Lock metadata on all dentries *before* getting attributes for them.
+ state.lockAllMetadata()
+ stats, err := state.start.file.multiGetAttr(ctx, state.names)
+ if err != nil {
+ return err
+ }
+
+ i := -1
+ for d := state.popFront(); d != nil; d = state.popFront() {
+ i++
+ found := i < len(stats)
+ if i == 0 && len(state.names[0]) == 0 {
+ if found && !d.isSynthetic() {
+ // First dentry is where the search is starting, just update attributes
+ // since it cannot be replaced.
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+ }
+ d.metadataMu.Unlock()
+ continue
+ }
+
+ // Note that synthetic dentries will always fails the comparison check
+ // below.
+ if !found || d.qidPath != stats[i].QID.Path {
+ d.metadataMu.Unlock()
+ if !found && d.isSynthetic() {
+ // We have a synthetic file, and no remote file has arisen to replace
+ // it.
+ return nil
+ }
+ // The file at this path has changed or no longer exists. Mark the
+ // dentry invalidated, and re-evaluate its caching status (i.e. if it
+ // has 0 references, drop it). The dentry will be reloaded next time it's
+ // accessed.
+ vfsObj.InvalidateDentry(ctx, &d.vfsd)
+
+ name := state.names[i]
+ d.parent.dirMu.Lock()
+
+ if d.isSynthetic() {
+ // Normally we don't mark invalidated dentries as deleted since
+ // they may still exist (but at a different path), and also for
+ // consistency with Linux. However, synthetic files are guaranteed
+ // to become unreachable if their dentries are invalidated, so
+ // treat their invalidation as deletion.
+ d.setDeleted()
+ d.decRefNoCaching()
+ *ds = appendDentry(*ds, d)
+
+ d.parent.syntheticChildren--
+ d.parent.dirents = nil
+ }
+
+ // Since the dirMu was released and reacquired, re-check that the
+ // parent's child with this name is still the same. Do not touch it if
+ // it has been replaced with a different one.
+ if child := d.parent.children[name]; child == d {
+ // Invalidate dentry so it gets reloaded next time it's accessed.
+ delete(d.parent.children, name)
+ }
+ d.parent.dirMu.Unlock()
+
+ return nil
+ }
+
+ // The file at this path hasn't changed. Just update cached metadata.
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+ d.metadataMu.Unlock()
+ }
+
+ return nil
+}
+
+// revalidateStatePool caches revalidateState instances to save array
+// allocations for dentries and names.
+var revalidateStatePool = sync.Pool{
+ New: func() interface{} {
+ return &revalidateState{}
+ },
+}
+
+// revalidateState keeps state related to a revalidation request. It keeps track
+// of {name, dentry} list being revalidated, as well as metadata locks on the
+// dentries. The list must be in ancestry order, in other words `n` must be
+// `n-1` child.
+type revalidateState struct {
+ // start is the dentry where to start the attributes search.
+ start *dentry
+
+ // List of names of entries to refresh attributes. Names length must be the
+ // same as detries length. They are kept in separate slices because names is
+ // used to call File.MultiGetAttr().
+ names []string
+
+ // dentries is the list of dentries that correspond to the names above.
+ // dentry.metadataMu is acquired as each dentry is added to this list.
+ dentries []*dentry
+
+ // locked indicates if metadata lock has been acquired on dentries.
+ locked bool
+}
+
+func makeRevalidateState(start *dentry) *revalidateState {
+ r := revalidateStatePool.Get().(*revalidateState)
+ r.start = start
+ return r
+}
+
+// release must be called after the caller is done with this object. It releases
+// all metadata locks and resources.
+func (r *revalidateState) release() {
+ r.reset()
+ revalidateStatePool.Put(r)
+}
+
+// Preconditions:
+// * d is a descendant of all dentries in r.dentries.
+func (r *revalidateState) add(name string, d *dentry) {
+ r.names = append(r.names, name)
+ r.dentries = append(r.dentries, d)
+}
+
+func (r *revalidateState) lockAllMetadata() {
+ for _, d := range r.dentries {
+ d.metadataMu.Lock()
+ }
+ r.locked = true
+}
+
+func (r *revalidateState) popFront() *dentry {
+ if len(r.dentries) == 0 {
+ return nil
+ }
+ d := r.dentries[0]
+ r.dentries = r.dentries[1:]
+ return d
+}
+
+// reset releases all metadata locks and resets all fields to allow this
+// instance to be reused.
+func (r *revalidateState) reset() {
+ if r.locked {
+ // Unlock any remaining dentries.
+ for _, d := range r.dentries {
+ d.metadataMu.Unlock()
+ }
+ r.locked = false
+ }
+ r.start = nil
+ r.names = r.names[:0]
+ r.dentries = r.dentries[:0]
+}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index badca4d9f..f50b0fb08 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -612,16 +612,24 @@ afterTrailingSymlink:
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
- fs.mu.RLock()
defer fs.processDeferredDecRefs(ctx)
- defer fs.mu.RUnlock()
+
+ fs.mu.RLock()
d, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
+ fs.mu.RUnlock()
return "", err
}
if !d.isSymlink() {
+ fs.mu.RUnlock()
return "", syserror.EINVAL
}
+
+ // Inode.Readlink() cannot be called holding fs locks.
+ d.IncRef()
+ defer d.DecRef(ctx)
+ fs.mu.RUnlock()
+
return d.inode.Readlink(ctx, rp.Mount())
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 16486eeae..6f699c9cd 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -534,6 +534,9 @@ func (d *Dentry) FSLocalPath() string {
// - Checking that dentries passed to methods are of the appropriate file type.
// - Checking permissions.
//
+// Inode functions may be called holding filesystem wide locks and are not
+// allowed to call vfs functions that may reenter, unless otherwise noted.
+//
// Specific responsibilities of implementations are documented below.
type Inode interface {
// Methods related to reference counting. A generic implementation is
@@ -680,6 +683,9 @@ type inodeDirectory interface {
type inodeSymlink interface {
// Readlink returns the target of a symbolic link. If an inode is not a
// symlink, the implementation should return EINVAL.
+ //
+ // Readlink is called with no kernfs locks held, so it may reenter if needed
+ // to resolve symlink targets.
Readlink(ctx context.Context, mnt *vfs.Mount) (string, error)
// Getlink returns the target of a symbolic link, as used by path
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 02bf74dbc..4718fac7a 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -221,6 +221,8 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error)
defer file.DecRef(ctx)
root := vfs.RootFromContext(ctx)
defer root.DecRef(ctx)
+
+ // Note: it's safe to reenter kernfs from Readlink if needed to resolve path.
return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
}
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index ca8090bbf..3582d14c9 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -168,10 +168,6 @@ afterSymlink:
// Preconditions:
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
-//
-// TODO(b/166474175): Investigate all possible errors returned in this
-// function, and make sure we differentiate all errors that indicate unexpected
-// modifications to the file system from the ones that are not harmful.
func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -278,16 +274,15 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
var buf bytes.Buffer
parent.hashMu.RLock()
_, err = merkletree.Verify(&merkletree.VerifyParams{
- Out: &buf,
- File: &fdReader,
- Tree: &fdReader,
- Size: int64(parentSize),
- Name: parent.name,
- Mode: uint32(parentStat.Mode),
- UID: parentStat.UID,
- GID: parentStat.GID,
- Children: parent.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ Children: parent.childrenNames,
HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: int64(offset),
ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
@@ -409,15 +404,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
var buf bytes.Buffer
d.hashMu.RLock()
params := &merkletree.VerifyParams{
- Out: &buf,
- Tree: &fdReader,
- Size: int64(size),
- Name: d.name,
- Mode: uint32(stat.Mode),
- UID: stat.UID,
- GID: stat.GID,
- Children: d.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ Children: d.childrenNames,
HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
@@ -991,8 +985,6 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
// StatAt implements vfs.FilesystemImpl.StatAt.
-// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should
-// be verified.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 458c7fcb6..31d34ef60 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -840,7 +840,6 @@ func (fd *fileDescription) Release(ctx context.Context) {
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
- // TODO(b/162788573): Add integrity check for metadata.
stat, err := fd.lowerFD.Stat(ctx, opts)
if err != nil {
return linux.Statx{}, err
@@ -960,10 +959,9 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui
}
params := &merkletree.GenerateParams{
- TreeReader: &merkleReader,
- TreeWriter: &merkleWriter,
- Children: fd.d.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ TreeReader: &merkleReader,
+ TreeWriter: &merkleWriter,
+ Children: fd.d.childrenNames,
HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
Name: fd.d.name,
Mode: uint32(stat.Mode),
@@ -1192,8 +1190,6 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
case linux.FS_IOC_GETFLAGS:
return fd.verityFlags(ctx, args[2].Pointer())
default:
- // TODO(b/169682228): Investigate which ioctl commands should
- // be allowed.
return 0, syserror.ENOSYS
}
}
@@ -1253,16 +1249,15 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
fd.d.hashMu.RLock()
n, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: dst.Writer(ctx),
- File: &dataReader,
- Tree: &merkleReader,
- Size: int64(size),
- Name: fd.d.name,
- Mode: fd.d.mode,
- UID: fd.d.uid,
- GID: fd.d.gid,
- Children: fd.d.childrenNames,
- //TODO(b/156980949): Support passing other hash algorithms.
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ Children: fd.d.childrenNames,
HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
ReadOffset: offset,
ReadSize: dst.NumBytes(),
@@ -1333,7 +1328,7 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t
func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
ts, err := fd.lowerMappable.Translate(ctx, required, optional, at)
if err != nil {
- return ts, err
+ return nil, err
}
// dataSize is the size of the whole file.
@@ -1346,17 +1341,17 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
if err == syserror.ENODATA {
- return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
- return ts, err
+ return nil, err
}
// The dataSize xattr should be an integer. If it's not, it indicates
// unexpected modifications to the file system.
size, err := strconv.Atoi(dataSize)
if err != nil {
- return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
merkleReader := FileReadWriteSeeker{
@@ -1389,7 +1384,7 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem
DataAndTreeInSameFile: false,
})
if err != nil {
- return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
+ return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
}
return ts, err
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index d7abfefb4..f727e61b0 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -351,6 +351,10 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
// allocations occur.
entersyscall()
bluepill(c)
+ // The root table physical page has to be mapped to not fault in iret
+ // or sysret after switching into a user address space. sysret and
+ // iret are in the upper half that is global and already mapped.
+ switchOpts.PageTables.PrefaultRootTable()
prefaultFloatingPointState(switchOpts.FloatingPointState)
vector = c.CPU.SwitchToUser(switchOpts)
exitsyscall()
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 080859125..7ee89a735 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 0e0e82365..2029e7cf4 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -14,9 +14,11 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/context",
"//pkg/hostarch",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 45a05cd63..235b9c306 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -18,9 +18,11 @@ package control
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -193,7 +195,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := binary.AlignDown(cap(buf)-len(buf), 4)
+ space := bits.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
@@ -230,7 +232,7 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([
return alignSlice(buf, align), flags
}
-func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte {
+func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte {
if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
return buf
}
@@ -241,8 +243,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = putUint32(buf, msgType)
hdrBuf := buf
-
- buf = binary.Marshal(buf, hostarch.ByteOrder, data)
+ buf = append(buf, marshal.Marshal(data)...)
// If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
@@ -288,7 +289,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := binary.AlignUp(len(buf), align)
+ aligned := bits.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -300,12 +301,13 @@ func alignSlice(buf []byte, align uint) []byte {
// PackTimestamp packs a SO_TIMESTAMP socket control message.
func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
+ timestampP := linux.NsecToTimeval(timestamp)
return putCmsgStruct(
buf,
linux.SOL_SOCKET,
linux.SO_TIMESTAMP,
t.Arch().Width(),
- linux.NsecToTimeval(timestamp),
+ &timestampP,
)
}
@@ -316,7 +318,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
linux.SOL_TCP,
linux.TCP_INQ,
t.Arch().Width(),
- inq,
+ primitive.AllocateInt32(inq),
)
}
@@ -327,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
linux.SOL_IP,
linux.IP_TOS,
t.Arch().Width(),
- tos,
+ primitive.AllocateUint8(tos),
)
}
@@ -338,7 +340,7 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
linux.SOL_IPV6,
linux.IPV6_TCLASS,
t.Arch().Width(),
- tClass,
+ primitive.AllocateUint32(tClass),
)
}
@@ -423,7 +425,7 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
// cmsgSpace is equivalent to CMSG_SPACE in Linux.
func cmsgSpace(t *kernel.Task, dataLen int) int {
- return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+ return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width())
}
// CmsgsSpace returns the number of bytes needed to fit the control messages
@@ -475,7 +477,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var h linux.ControlMessageHeader
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
+ h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
return socket.ControlMessages{}, syserror.EINVAL
@@ -491,7 +493,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
case linux.SOL_SOCKET:
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -502,7 +504,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -510,23 +512,23 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var creds linux.ControlMessageCredentials
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
+ creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
scmCreds, err := NewSCMCredentials(t, creds)
if err != nil {
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.SO_TIMESTAMP:
if length < linux.SizeOfTimeval {
return socket.ControlMessages{}, syserror.EINVAL
}
var ts linux.Timeval
- binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts)
+ ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
cmsgs.IP.Timestamp = ts.ToNsecCapped()
cmsgs.IP.HasTimestamp = true
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
default:
// Unknown message type.
@@ -539,8 +541,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
cmsgs.IP.HasTOS = true
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS)
- i += binary.AlignUp(length, width)
+ var tos primitive.Uint8
+ tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS])
+ cmsgs.IP.TOS = uint8(tos)
+ i += bits.AlignUp(length, width)
case linux.IP_PKTINFO:
if length < linux.SizeOfControlMessageIPPacketInfo {
@@ -549,19 +553,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
cmsgs.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo)
+ packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo])
cmsgs.IP.PacketInfo = packetInfo
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
if length < addr.SizeBytes() {
return socket.ControlMessages{}, syserror.EINVAL
}
- binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
cmsgs.IP.OriginalDstAddress = &addr
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.IP_RECVERR:
var errCmsg linux.SockErrCMsgIPv4
@@ -571,7 +575,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
cmsgs.IP.SockErr = &errCmsg
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
@@ -583,17 +587,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, syserror.EINVAL
}
cmsgs.IP.HasTClass = true
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass)
- i += binary.AlignUp(length, width)
+ var tclass primitive.Uint32
+ tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass])
+ cmsgs.IP.TClass = uint32(tclass)
+ i += bits.AlignUp(length, width)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
if length < addr.SizeBytes() {
return socket.ControlMessages{}, syserror.EINVAL
}
- binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
cmsgs.IP.OriginalDstAddress = &addr
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
case linux.IPV6_RECVERR:
var errCmsg linux.SockErrCMsgIPv6
@@ -603,7 +609,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
cmsgs.IP.SockErr = &errCmsg
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a5c2155a2..2e3064565 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -17,7 +17,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/hostarch",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 0d3b23643..52ae4bc9c 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
@@ -529,7 +528,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.SO_TIMESTAMP:
controlMessages.IP.HasTimestamp = true
ts := linux.Timeval{}
- ts.UnmarshalBytes(unixCmsg.Data[:linux.SizeOfTimeval])
+ ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
controlMessages.IP.Timestamp = ts.ToNsecCapped()
}
@@ -537,17 +536,19 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.IP_TOS:
controlMessages.IP.HasTOS = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS)
+ var tos primitive.Uint8
+ tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()])
+ controlMessages.IP.TOS = uint8(tos)
case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo)
+ packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()])
controlMessages.IP.PacketInfo = packetInfo
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
- binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
controlMessages.IP.OriginalDstAddress = &addr
case unix.IP_RECVERR:
@@ -560,11 +561,13 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass)
+ var tclass primitive.Uint32
+ tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()])
+ controlMessages.IP.TClass = uint32(tclass)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
- binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
controlMessages.IP.OriginalDstAddress = &addr
case unix.IPV6_RECVERR:
@@ -577,7 +580,9 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
switch unixCmsg.Header.Type {
case linux.TCP_INQ:
controlMessages.IP.HasInq = true
- binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq)
+ var inq primitive.Int32
+ inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq])
+ controlMessages.IP.Inq = int32(inq)
}
}
}
@@ -691,7 +696,7 @@ func (s *socketOpsCommon) State() uint32 {
return 0
}
- binary.Unmarshal(buf, hostarch.ByteOrder, &info)
+ info.UnmarshalUnsafe(buf[:info.SizeBytes()])
return uint32(info.State)
}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 26e8ae17a..393a1ab3a 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -15,6 +15,7 @@
package hostinet
import (
+ "encoding/binary"
"fmt"
"io"
"io/ioutil"
@@ -26,10 +27,10 @@ import (
"syscall"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -147,8 +148,8 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
if len(link.Data) < unix.SizeofIfInfomsg {
return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg)
}
- var ifinfo unix.IfInfomsg
- binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo)
+ var ifinfo linux.InterfaceInfoMessage
+ ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()])
inetIF := inet.Interface{
DeviceType: ifinfo.Type,
Flags: ifinfo.Flags,
@@ -178,11 +179,11 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
if len(addr.Data) < unix.SizeofIfAddrmsg {
return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg)
}
- var ifaddr unix.IfAddrmsg
- binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr)
+ var ifaddr linux.InterfaceAddrMessage
+ ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()])
inetAddr := inet.InterfaceAddr{
Family: ifaddr.Family,
- PrefixLen: ifaddr.Prefixlen,
+ PrefixLen: ifaddr.PrefixLen,
Flags: ifaddr.Flags,
}
attrs, err := syscall.ParseNetlinkRouteAttr(&addr)
@@ -210,13 +211,13 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
continue
}
- var ifRoute unix.RtMsg
- binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute)
+ var ifRoute linux.RouteMessage
+ ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()])
inetRoute := inet.Route{
Family: ifRoute.Family,
- DstLen: ifRoute.Dst_len,
- SrcLen: ifRoute.Src_len,
- TOS: ifRoute.Tos,
+ DstLen: ifRoute.DstLen,
+ SrcLen: ifRoute.SrcLen,
+ TOS: ifRoute.TOS,
Table: ifRoute.Table,
Protocol: ifRoute.Protocol,
Scope: ifRoute.Scope,
@@ -245,7 +246,9 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
if len(attr.Value) != expected {
return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
}
- binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface)
+ var outputIF primitive.Int32
+ outputIF.UnmarshalUnsafe(attr.Value)
+ inetRoute.OutputInterface = int32(outputIF)
}
}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 4381dfa06..61b2c9755 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -14,14 +14,16 @@ go_library(
"tcp_matcher.go",
"udp_matcher.go",
],
+ marshal = True,
# This target depends on netstack and should only be used by epsocket,
# which is allowed to depend on netstack.
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/hostarch",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/sentry/kernel",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 4bd305a44..6fc7781ad 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -79,7 +78,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
nflog("marshaling matcher %q", name)
// We have to pad this struct size to a multiple of 8 bytes.
- size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ size := bits.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
matcher := linux.KernelXTEntryMatch{
XTEntryMatch: linux.XTEntryMatch{
MatchSize: uint16(size),
@@ -88,9 +87,11 @@ func marshalEntryMatch(name string, data []byte) []byte {
}
copy(matcher.Name[:], name)
- buf := make([]byte, 0, size)
- buf = binary.Marshal(buf, hostarch.ByteOrder, matcher)
- return append(buf, make([]byte, size-len(buf))...)
+ buf := make([]byte, size)
+ entryLen := matcher.XTEntryMatch.SizeBytes()
+ matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen])
+ copy(buf[entryLen:], matcher.Data)
+ return buf
}
func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index 1fc4cb651..cb78ef60b 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -18,8 +18,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -141,10 +139,9 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
var entry linux.IPTEntry
- buf := optVal[:linux.SizeOfIPTEntry]
- binary.Unmarshal(buf, hostarch.ByteOrder, &entry)
+ entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[linux.SizeOfIPTEntry:]
+ optVal = optVal[entry.SizeBytes():]
if entry.TargetOffset < linux.SizeOfIPTEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 67a52b628..5cb7fe4aa 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -18,8 +18,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -144,10 +142,9 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace,
return nil, syserr.ErrInvalidArgument
}
var entry linux.IP6TEntry
- buf := optVal[:linux.SizeOfIP6TEntry]
- binary.Unmarshal(buf, hostarch.ByteOrder, &entry)
+ entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[linux.SizeOfIP6TEntry:]
+ optVal = optVal[entry.SizeBytes():]
if entry.TargetOffset < linux.SizeOfIP6TEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index c6fa3fd16..f42d73178 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -22,7 +22,6 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -121,7 +120,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe
nflog("couldn't read entries: %v", err)
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
- if binary.Size(entries) > uintptr(outLen) {
+ if entries.SizeBytes() > outLen {
nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
@@ -146,7 +145,7 @@ func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe
nflog("couldn't read entries: %v", err)
return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
}
- if binary.Size(entries) > uintptr(outLen) {
+ if entries.SizeBytes() > outLen {
nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
}
@@ -179,7 +178,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
- binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace)
+ replace.UnmarshalBytes(replaceBuf)
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
@@ -309,8 +308,8 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
}
var match linux.XTEntryMatch
- buf := optVal[:linux.SizeOfXTEntryMatch]
- binary.Unmarshal(buf, hostarch.ByteOrder, &match)
+ buf := optVal[:match.SizeBytes()]
+ match.UnmarshalUnsafe(buf)
nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
// Check some invariants.
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index b2cc6be20..60845cab3 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -59,8 +58,8 @@ func (ownerMarshaler) marshal(mr matcher) []byte {
}
}
- buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
- return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo))
+ buf := marshal.Marshal(&iptOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, buf)
}
// unmarshal implements matchMaker.unmarshal.
@@ -72,7 +71,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
- binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData)
+ matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo])
nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 4ae1592b2..e94aceb92 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,11 +15,12 @@
package netfilter
import (
+ "encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -189,8 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte {
Verdict: verdict,
}
- ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -199,8 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
var standardTarget linux.XTStandardTarget
- buf = buf[:linux.SizeOfXTStandardTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget)
+ standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()])
if standardTarget.Verdict < 0 {
// A Verdict < 0 indicates a non-jump verdict.
@@ -245,8 +244,7 @@ func (*errorTargetMaker) marshal(target target) []byte {
copy(xt.Name[:], errorName)
copy(xt.Target.Name[:], ErrorTargetName)
- ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -256,7 +254,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt)
+ errTgt.UnmarshalUnsafe(buf)
// Error targets are used in 2 cases:
// * An actual error case. These rules have an error named
@@ -299,12 +297,11 @@ func (*redirectTargetMaker) marshal(target target) []byte {
}
copy(xt.Target.Name[:], RedirectTargetName)
- ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
xt.NfRange.RangeIPV4.MinPort = htons(rt.Port)
xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -320,7 +317,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &rt)
+ rt.UnmarshalUnsafe(buf)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
target := redirectTarget{RedirectTarget: stack.RedirectTarget{
@@ -359,6 +356,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return &target, nil
}
+// +marshal
type nfNATTarget struct {
Target linux.XTEntryTarget
Range linux.NFNATRange
@@ -394,8 +392,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
- ret := make([]byte, 0, nfNATMarshalledSize)
- return binary.Marshal(ret, hostarch.ByteOrder, nt)
+ return marshal.Marshal(&nt)
}
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -411,7 +408,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
var natRange linux.NFNATRange
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
- binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
+ natRange.UnmarshalUnsafe(buf)
// We don't support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -468,8 +465,7 @@ func (*snatTargetMakerV4) marshal(target target) []byte {
xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr)
copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr)
- ret := make([]byte, 0, linux.SizeOfXTSNATTarget)
- return binary.Marshal(ret, hostarch.ByteOrder, xt)
+ return marshal.Marshal(&xt)
}
func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -485,7 +481,7 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
var st linux.XTSNATTarget
buf = buf[:linux.SizeOfXTSNATTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &st)
+ st.UnmarshalUnsafe(buf)
// Copy linux.XTSNATTarget to stack.SNATTarget.
target := snatTarget{SNATTarget: stack.SNATTarget{
@@ -550,8 +546,7 @@ func (*snatTargetMakerV6) marshal(target target) []byte {
nt.Range.MinProto = htons(st.Port)
nt.Range.MaxProto = nt.Range.MinProto
- ret := make([]byte, 0, nfNATMarshalledSize)
- return binary.Marshal(ret, hostarch.ByteOrder, nt)
+ return marshal.Marshal(&nt)
}
func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
@@ -567,7 +562,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
var natRange linux.NFNATRange
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
- binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
+ natRange.UnmarshalUnsafe(buf)
// TODO(gvisor.dev/issue/5689): Support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -631,8 +626,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T
return nil, syserr.ErrInvalidArgument
}
var target linux.XTEntryTarget
- buf := optVal[:linux.SizeOfXTEntryTarget]
- binary.Unmarshal(buf, hostarch.ByteOrder, &target)
+ target.UnmarshalUnsafe(optVal[:target.SizeBytes()])
return unmarshalTarget(target, filter, optVal)
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 69557f515..95bb9826e 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,8 +46,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte {
DestinationPortStart: matcher.destinationPortStart,
DestinationPortEnd: matcher.destinationPortEnd,
}
- buf := make([]byte, 0, linux.SizeOfXTTCP)
- return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp))
+ return marshalEntryMatch(matcherNameTCP, marshal.Marshal(&xttcp))
}
// unmarshal implements matchMaker.unmarshal.
@@ -60,7 +58,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.XTTCP
- binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData)
+ matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
nflog("parseMatchers: parsed XTTCP: %+v", matchData)
if matchData.Option != 0 ||
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 6a60e6bd6..fb8be27e6 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -18,8 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,8 +46,7 @@ func (udpMarshaler) marshal(mr matcher) []byte {
DestinationPortStart: matcher.destinationPortStart,
DestinationPortEnd: matcher.destinationPortEnd,
}
- buf := make([]byte, 0, linux.SizeOfXTUDP)
- return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp))
+ return marshalEntryMatch(matcherNameUDP, marshal.Marshal(&xtudp))
}
// unmarshal implements matchMaker.unmarshal.
@@ -60,7 +58,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
// For alignment reasons, the match's total size may exceed what's
// strictly necessary to hold matchData.
var matchData linux.XTUDP
- binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData)
+ matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
nflog("parseMatchers: parsed XTUDP: %+v", matchData)
if matchData.InverseFlags != 0 {
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 171b95c63..64cd263da 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -14,7 +14,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
@@ -50,5 +50,7 @@ go_test(
deps = [
":netlink",
"//pkg/abi/linux",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index ab0e68af7..80385bfdc 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -19,15 +19,17 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
)
// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
func alignPad(length int, align uint) int {
- return binary.AlignUp(length, align) - length
+ return bits.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
@@ -42,7 +44,7 @@ type Message struct {
func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
return &Message{
hdr: hdr,
- buf: binary.Marshal(nil, hostarch.ByteOrder, hdr),
+ buf: marshal.Marshal(&hdr),
}
}
@@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
return
}
var hdr linux.NetlinkMessageHeader
- binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBytes)
// Msg portion.
totalMsgLen := int(hdr.Length)
@@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader {
// GetData unmarshals the payload message header from this netlink message, and
// returns the attributes portion.
-func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) {
b := BytesView(m.buf)
_, ok := b.Extract(linux.NetlinkMessageHeaderSize)
@@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
return nil, false
}
- size := int(binary.Size(msg))
+ size := msg.SizeBytes()
msgBytes, ok := b.Extract(size)
if !ok {
return nil, false
}
- binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg)
+ msg.UnmarshalUnsafe(msgBytes)
numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
// Linux permits the last message not being aligned, just consume all of it.
@@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) {
}
// Put serializes v into the message.
-func (m *Message) Put(v interface{}) {
- m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v)
+func (m *Message) Put(v marshal.Marshallable) {
+ m.buf = append(m.buf, marshal.Marshal(v)...)
}
// PutAttr adds v to the message as a netlink attribute.
//
// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
-// binary.Size(v) fits in math.MaxUint16 bytes.
-func (m *Message) PutAttr(atype uint16, v interface{}) {
- l := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+// v.SizeBytes()) fits in math.MaxUint16 bytes.
+func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) {
+ l := linux.NetlinkAttrHeaderSize + v.SizeBytes()
if l > math.MaxUint16 {
panic(fmt.Sprintf("attribute too large: %d", l))
}
- m.Put(linux.NetlinkAttrHeader{
+ m.Put(&linux.NetlinkAttrHeader{
Type: atype,
Length: uint16(l),
})
m.Put(v)
// Align the attribute.
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
// PutAttrString adds s to the message as a netlink attribute.
func (m *Message) PutAttrString(atype uint16, s string) {
l := linux.NetlinkAttrHeaderSize + len(s) + 1
- m.Put(linux.NetlinkAttrHeader{
+ m.Put(&linux.NetlinkAttrHeader{
Type: atype,
Length: uint16(l),
})
// String + NUL-termination.
- m.Put([]byte(s))
+ m.Put(primitive.AsByteSlice([]byte(s)))
m.putZeros(1)
// Align the attribute.
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest
if !ok {
return
}
- binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBytes)
value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
if !ok {
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
index ef13d9386..968968469 100644
--- a/pkg/sentry/socket/netlink/message_test.go
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -20,13 +20,31 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)
type dummyNetlinkMsg struct {
+ marshal.StubMarshallable
Foo uint16
}
+func (*dummyNetlinkMsg) SizeBytes() int {
+ return 2
+}
+
+func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
+ p := primitive.Uint16(m.Foo)
+ p.MarshalUnsafe(dst)
+}
+
+func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
+ var p primitive.Uint16
+ p.UnmarshalUnsafe(src)
+ m.Foo = uint16(p)
+}
+
func TestParseMessage(t *testing.T) {
tests := []struct {
desc string
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 744fc74f4..c6c04b4e3 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -11,6 +11,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 5a2255db3..86f6419dc 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
Type: linux.RTM_NEWLINK,
})
- m.Put(linux.InterfaceInfoMessage{
+ m.Put(&linux.InterfaceInfoMessage{
Family: linux.AF_UNSPEC,
Type: i.DeviceType,
Index: idx,
@@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
})
m.PutAttrString(linux.IFLA_IFNAME, i.Name)
- m.PutAttr(linux.IFLA_MTU, i.MTU)
+ m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU))
mac := make([]byte, 6)
brd := mac
@@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
mac = i.Addr
brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
}
- m.PutAttr(linux.IFLA_ADDRESS, mac)
- m.PutAttr(linux.IFLA_BROADCAST, brd)
+ m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac))
+ m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd))
// TODO(gvisor.dev/issue/578): There are many more attributes.
}
@@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl
Type: linux.RTM_NEWADDR,
})
- m.Put(linux.InterfaceAddrMessage{
+ m.Put(&linux.InterfaceAddrMessage{
Family: a.Family,
PrefixLen: a.PrefixLen,
Index: uint32(id),
})
- m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
- m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
+ addr := primitive.ByteSlice([]byte(a.Addr))
+ m.PutAttr(linux.IFA_LOCAL, &addr)
+ m.PutAttr(linux.IFA_ADDRESS, &addr)
// TODO(gvisor.dev/issue/578): There are many more attributes.
}
@@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net
Type: linux.RTM_NEWROUTE,
})
- m.Put(linux.RouteMessage{
+ m.Put(&linux.RouteMessage{
Family: rt.Family,
DstLen: rt.DstLen,
SrcLen: rt.SrcLen,
@@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net
Flags: rt.Flags,
})
- m.PutAttr(254, []byte{123})
+ m.PutAttr(254, primitive.AsByteSlice([]byte{123}))
if rt.DstLen > 0 {
- m.PutAttr(linux.RTA_DST, rt.DstAddr)
+ m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr))
}
if rt.SrcLen > 0 {
- m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+ m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr))
}
if rt.OutputInterface != 0 {
- m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+ m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface))
}
if len(rt.GatewayAddr) > 0 {
- m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+ m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr))
}
// TODO(gvisor.dev/issue/578): There are many more attributes.
@@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
hdr := msg.Header()
// All messages start with a 1 byte protocol family.
- var family uint8
+ var family primitive.Uint8
if _, ok := msg.GetData(&family); !ok {
// Linux ignores messages missing the protocol family. See
// net/core/rtnetlink.c:rtnetlink_rcv_msg.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 30c297149..d75a2879f 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -20,7 +20,6 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa)
+ sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument
@@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- sendBufferSizeP := primitive.Int32(s.sendBufferSize)
- return &sendBufferSizeP, nil
+ return primitive.AllocateInt32(int32(s.sendBufferSize)), nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have limit on receiving size.
- recvBufferSizeP := primitive.Int32(math.MaxInt32)
- return &recvBufferSizeP, nil
+ return primitive.AllocateInt32(math.MaxInt32), nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
Family: linux.AF_NETLINK,
PortID: uint32(s.portID),
}
- return sa, uint32(binary.Size(sa)), nil
+ return sa, uint32(sa.SizeBytes()), nil
}
// GetPeerName implements socket.Socket.GetPeerName.
@@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
// must be the kernel.
PortID: 0,
}
- return sa, uint32(binary.Size(sa)), nil
+ return sa, uint32(sa.SizeBytes()), nil
}
// RecvMsg implements socket.Socket.RecvMsg.
@@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Family: linux.AF_NETLINK,
PortID: 0,
}
- fromLen := uint32(binary.Size(from))
+ fromLen := uint32(from.SizeBytes())
trunc := flags&linux.MSG_TRUNC != 0
@@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys
})
// Add the dump_done_errno payload.
- m.Put(int64(0))
+ m.Put(primitive.AllocateInt64(0))
_, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
@@ -658,7 +655,7 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
- m.Put(linux.NetlinkErrorMessage{
+ m.Put(&linux.NetlinkErrorMessage{
Error: int32(-err.ToLinux().Number()),
Header: hdr,
})
@@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
- m.Put(linux.NetlinkErrorMessage{
+ m.Put(&linux.NetlinkErrorMessage{
Error: 0,
Header: hdr,
})
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 0b39a5b67..9561b7c25 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -19,7 +19,6 @@ go_library(
],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/hostarch",
"//pkg/log",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 312f5f85a..264f8d926 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "encoding/binary"
"fmt"
"io"
"io/ioutil"
@@ -35,7 +36,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
@@ -375,9 +375,9 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
}), nil
}
-var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
-var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
-var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
+var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes()
+var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes()
+var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes()
// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
// netstack representation taking any addresses into account.
@@ -613,7 +613,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < sockAddrLinkSize {
return syserr.ErrInvalidArgument
}
- binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a)
+ a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
if a.Protocol != uint16(s.protocol) {
return syserr.ErrInvalidArgument
@@ -1312,7 +1312,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return &v, nil
case linux.IP6T_ORIGINAL_DST:
- if outLen < int(binary.Size(linux.SockAddrInet6{})) {
+ if outLen < sockAddrInet6Size {
return nil, syserr.ErrInvalidArgument
}
@@ -1509,7 +1509,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return &v, nil
case linux.SO_ORIGINAL_DST:
- if outLen < int(binary.Size(linux.SockAddrInet{})) {
+ if outLen < sockAddrInetSize {
return nil, syserr.ErrInvalidArgument
}
@@ -1742,7 +1742,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v)
+ v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1755,7 +1755,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v)
+ v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1791,7 +1791,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Linger
- binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v)
+ v.UnmarshalBytes(optVal[:linux.SizeOfLinger])
+
+ if v != (linux.Linger{}) {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
ep.SocketOptions().SetLinger(tcpip.LingerOption{
Enabled: v.OnOff != 0,
@@ -2090,9 +2094,9 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
var (
- inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
- inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
- inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{}))
+ inetMulticastRequestSize = (*linux.InetMulticastRequest)(nil).SizeBytes()
+ inetMulticastRequestWithNICSize = (*linux.InetMulticastRequestWithNIC)(nil).SizeBytes()
+ inet6MulticastRequestSize = (*linux.Inet6MulticastRequest)(nil).SizeBytes()
)
// copyInMulticastRequest copies in a variable-size multicast request. The
@@ -2117,12 +2121,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
if len(optVal) >= inetMulticastRequestWithNICSize {
var req linux.InetMulticastRequestWithNIC
- binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req)
+ req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize])
return req, nil
}
var req linux.InetMulticastRequestWithNIC
- binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest)
+ req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize])
return req, nil
}
@@ -2132,7 +2136,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse
}
var req linux.Inet6MulticastRequest
- binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req)
+ req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize])
return req, nil
}
@@ -3101,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
continue
}
// Populate ifr.ifr_netmask (type sockaddr).
- hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
- hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ hostarch.ByteOrder.PutUint16(ifr.Data[0:], uint16(linux.AF_INET))
+ hostarch.ByteOrder.PutUint16(ifr.Data[2:], 0)
var mask uint32 = 0xffffffff << (32 - addr.PrefixLen)
// Netmask is expected to be returned as a big endian
// value.
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 4c3d48096..9e56487a6 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -24,7 +24,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -572,19 +571,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
switch family {
case unix.AF_INET:
var addr linux.SockAddrInet
- binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
case unix.AF_INET6:
var addr linux.SockAddrInet6
- binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
case unix.AF_UNIX:
var addr linux.SockAddrUnix
- binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
case unix.AF_NETLINK:
var addr linux.SockAddrNetlink
- binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr)
+ addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
return &addr
default:
panic(fmt.Sprintf("Unsupported socket family %v", family))
@@ -716,7 +715,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInetSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a)
+ a.UnmarshalUnsafe(addr[:sockAddrInetSize])
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -729,7 +728,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInet6Size {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a)
+ a.UnmarshalUnsafe(addr[:sockAddrInet6Size])
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -745,7 +744,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrLinkSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a)
+ a.UnmarshalUnsafe(addr[:sockAddrLinkSize])
if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 2ebd77f82..1fbbd133c 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -25,7 +25,6 @@ go_library(
":strace_go_proto",
"//pkg/abi",
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bits",
"//pkg/eventchannel",
"//pkg/hostarch",
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index 71b92eaee..d66befe81 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -371,6 +371,7 @@ var linuxAMD64 = SyscallMap{
433: makeSyscallInfo("fspick", FD, Path, Hex),
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
+ 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet),
}
func init() {
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index bd7361a52..1a2d7d75f 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -312,6 +312,7 @@ var linuxARM64 = SyscallMap{
433: makeSyscallInfo("fspick", FD, Path, Hex),
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
+ 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet),
}
func init() {
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index e5b7f9b96..f4aab25b0 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -20,14 +20,13 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// SocketFamily are the possible socket(2) families.
@@ -162,6 +161,15 @@ var controlMessageType = map[int32]string{
linux.SO_TIMESTAMP: "SO_TIMESTAMP",
}
+func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights {
+ count := len(src) / linux.SizeOfControlMessageRight
+ cmr := make(linux.ControlMessageRights, count)
+ for i, _ := range cmr {
+ cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:]))
+ }
+ return cmr
+}
+
func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string {
if length > maxBytes {
return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length)
@@ -181,7 +189,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
}
var h linux.ControlMessageHeader
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
+ h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
var skipData bool
level := "SOL_SOCKET"
@@ -221,18 +229,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
continue
}
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
-
- numRights := rightsSize / linux.SizeOfControlMessageRight
- fds := make(linux.ControlMessageRights, numRights)
- binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds)
-
+ rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
+ fds := unmarshalControlMessageRights(buf[i : i+rightsSize])
rights := make([]string, 0, len(fds))
for _, fd := range fds {
rights = append(rights, fmt.Sprint(fd))
@@ -258,7 +262,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
}
var creds linux.ControlMessageCredentials
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
+ creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
@@ -282,7 +286,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
}
var tv linux.Timeval
- binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv)
+ tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
@@ -296,7 +300,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
default:
panic("unreachable")
}
- i += binary.AlignUp(length, width)
+ i += bits.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
index e115683f8..3b4d79889 100644
--- a/pkg/sentry/syscalls/epoll.go
+++ b/pkg/sentry/syscalls/epoll.go
@@ -119,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
}
// WaitEpoll implements the epoll_wait(2) linux syscall.
-func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) {
+func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux.EpollEvent, error) {
// Get epoll from the file descriptor.
epollfile := t.GetFile(fd)
if epollfile == nil {
@@ -136,7 +136,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve
// Try to read events and return right away if we got them or if the
// caller requested a non-blocking "wait".
r := e.ReadEvents(max)
- if len(r) != 0 || timeout == 0 {
+ if len(r) != 0 || timeoutInNanos == 0 {
return r, nil
}
@@ -144,8 +144,8 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve
// and register with the epoll object for readability events.
var haveDeadline bool
var deadline ktime.Time
- if timeout > 0 {
- timeoutDur := time.Duration(timeout) * time.Millisecond
+ if timeoutInNanos > 0 {
+ timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond
deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
haveDeadline = true
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 2d2212605..090c5ffcb 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -404,6 +404,7 @@ var AMD64 = &kernel.SyscallTable{
433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ 441: syscalls.Supported("epoll_pwait2", EpollPwait2),
},
Emulate: map[hostarch.Addr]uintptr{
0xffffffffff600000: 96, // vsyscall gettimeofday(2)
@@ -722,6 +723,7 @@ var ARM64 = &kernel.SyscallTable{
433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ 441: syscalls.Supported("epoll_pwait2", EpollPwait2),
},
Emulate: map[hostarch.Addr]uintptr{},
Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 7f460d30b..69cbc98d0 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
@@ -104,14 +105,8 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
}
-// EpollWait implements the epoll_wait(2) linux syscall.
-func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- epfd := args[0].Int()
- eventsAddr := args[1].Pointer()
- maxEvents := int(args[2].Int())
- timeout := int(args[3].Int())
-
- r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout)
+func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) {
+ r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos)
if err != nil {
return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
}
@@ -123,6 +118,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
return uintptr(len(r)), nil, nil
+
+}
+
+// EpollWait implements the epoll_wait(2) linux syscall.
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ // Convert milliseconds to nanoseconds.
+ timeoutInNanos := int64(args[3].Int()) * 1000000
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
}
// EpollPwait implements the epoll_pwait(2) linux syscall.
@@ -144,4 +150,38 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+// EpollPwait2 implements the epoll_pwait(2) linux syscall.
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeoutPtr := args[3].Pointer()
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+ haveTimeout := timeoutPtr != 0
+
+ var timeoutInNanos int64 = -1
+ if haveTimeout {
+ timeout, err := copyTimespecIn(t, timeoutPtr)
+ if err != nil {
+ return 0, nil, err
+ }
+ timeoutInNanos = timeout.ToNsec()
+
+ }
+
+ if maskAddr != 0 {
+ mask, err := CopyInSigSet(t, maskAddr, maskSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ }
+
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
+}
+
// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 5e9e940df..e07917613 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -463,8 +463,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- vLen := int32(v.SizeBytes())
- if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index b980aa43e..047d955b6 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -118,13 +119,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
}
-// EpollWait implements Linux syscall epoll_wait(2).
-func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- epfd := args[0].Int()
- eventsAddr := args[1].Pointer()
- maxEvents := int(args[2].Int())
- timeout := int(args[3].Int())
-
+func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) {
var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
return 0, nil, syserror.EINVAL
@@ -158,7 +153,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
return 0, nil, err
}
- if timeout == 0 {
+ if timeoutInNanos == 0 {
return 0, nil, nil
}
// In the first iteration of this loop, register with the epoll
@@ -173,8 +168,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
defer epfile.EventUnregister(&w)
} else {
// Set up the timer if a timeout was specified.
- if timeout > 0 && !haveDeadline {
- timeoutDur := time.Duration(timeout) * time.Millisecond
+ if timeoutInNanos > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond
deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
haveDeadline = true
}
@@ -186,6 +181,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
}
}
+
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeoutInNanos := int64(args[3].Int()) * 1000000
+
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
}
// EpollPwait implements Linux syscall epoll_pwait(2).
@@ -199,3 +205,29 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// EpollPwait2 implements Linux syscall epoll_pwait(2).
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeoutPtr := args[3].Pointer()
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+ haveTimeout := timeoutPtr != 0
+
+ var timeoutInNanos int64 = -1
+ if haveTimeout {
+ var timeout linux.Timespec
+ if _, err := timeout.CopyIn(t, timeoutPtr); err != nil {
+ return 0, nil, err
+ }
+ timeoutInNanos = timeout.ToNsec()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 6edde0ed1..69f69e3af 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -467,8 +467,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- vLen := int32(v.SizeBytes())
- if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index c50fd97eb..0fc81e694 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -159,6 +159,7 @@ func Override() {
s.Table[327] = syscalls.Supported("preadv2", Preadv2)
s.Table[328] = syscalls.Supported("pwritev2", Pwritev2)
s.Table[332] = syscalls.Supported("statx", Statx)
+ s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2)
s.Init()
// Override ARM64.
@@ -269,6 +270,7 @@ func Override() {
s.Table[286] = syscalls.Supported("preadv2", Preadv2)
s.Table[287] = syscalls.Supported("pwritev2", Pwritev2)
s.Table[291] = syscalls.Supported("statx", Statx)
+ s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2)
s.Init()
}
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index f612a71b2..176bcc242 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -524,7 +524,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St
Start: fd.vd,
})
stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return stat, err
}
return fd.impl.Stat(ctx, opts)
@@ -539,7 +539,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
return fd.impl.SetStat(ctx, opts)
@@ -555,7 +555,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
Start: fd.vd,
})
statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return statfs, err
}
return fd.impl.StatFS(ctx)
@@ -701,7 +701,7 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string
Start: fd.vd,
})
names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return names, err
}
names, err := fd.impl.ListXattr(ctx, size)
@@ -730,7 +730,7 @@ func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions)
Start: fd.vd,
})
val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return val, err
}
return fd.impl.GetXattr(ctx, *opts)
@@ -746,7 +746,7 @@ func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions)
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
return fd.impl.SetXattr(ctx, *opts)
@@ -762,7 +762,7 @@ func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error {
Start: fd.vd,
})
err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
return fd.impl.RemoveXattr(ctx, name)
diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go
index 39fbac987..47848c76b 100644
--- a/pkg/sentry/vfs/opath.go
+++ b/pkg/sentry/vfs/opath.go
@@ -121,7 +121,7 @@ func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, err
Start: fd.vfsfd.vd,
})
stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return stat, err
}
@@ -134,6 +134,6 @@ func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) {
Start: fd.vfsfd.vd,
})
statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp)
- vfsObj.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return statfs, err
}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index e4fd55012..97b898aba 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -44,13 +44,10 @@ type ResolvingPath struct {
start *Dentry
pit fspath.Iterator
- flags uint16
- mustBeDir bool // final file must be a directory?
- mustBeDirOrig bool
- symlinks uint8 // number of symlinks traversed
- symlinksOrig uint8
- curPart uint8 // index into parts
- numOrigParts uint8
+ flags uint16
+ mustBeDir bool // final file must be a directory?
+ symlinks uint8 // number of symlinks traversed
+ curPart uint8 // index into parts
creds *auth.Credentials
@@ -60,14 +57,9 @@ type ResolvingPath struct {
nextStart *Dentry // ref held if not nil
absSymlinkTarget fspath.Path
- // ResolvingPath must track up to two relative paths: the "current"
- // relative path, which is updated whenever a relative symlink is
- // encountered, and the "original" relative path, which is updated from the
- // current relative path by handleError() when resolution must change
- // filesystems (due to reaching a mount boundary or absolute symlink) and
- // overwrites the current relative path when Restart() is called.
- parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
- origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
+ // ResolvingPath tracks relative paths, which is updated whenever a relative
+ // symlink is encountered.
+ parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
}
const (
@@ -120,6 +112,8 @@ var resolvingPathPool = sync.Pool{
},
}
+// getResolvingPath gets a new ResolvingPath from the pool. Caller must call
+// ResolvingPath.Release() when done.
func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath {
rp := resolvingPathPool.Get().(*ResolvingPath)
rp.vfs = vfs
@@ -132,17 +126,37 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat
rp.flags |= rpflagsFollowFinalSymlink
}
rp.mustBeDir = pop.Path.Dir
- rp.mustBeDirOrig = pop.Path.Dir
rp.symlinks = 0
rp.curPart = 0
- rp.numOrigParts = 1
rp.creds = creds
rp.parts[0] = pop.Path.Begin
- rp.origParts[0] = pop.Path.Begin
return rp
}
-func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) {
+// Copy creates another ResolvingPath with the same state as the original.
+// Copies are independent, using the copy does not change the original and
+// vice-versa.
+//
+// Caller must call Resease() when done.
+func (rp *ResolvingPath) Copy() *ResolvingPath {
+ copy := resolvingPathPool.Get().(*ResolvingPath)
+ *copy = *rp // All fields all shallow copiable.
+
+ // Take extra reference for the copy if the original had them.
+ if copy.flags&rpflagsHaveStartRef != 0 {
+ copy.start.IncRef()
+ }
+ if copy.flags&rpflagsHaveMountRef != 0 {
+ copy.mount.IncRef()
+ }
+ // Reset error state.
+ copy.nextStart = nil
+ copy.nextMount = nil
+ return copy
+}
+
+// Release decrements references if needed and returns the object to the pool.
+func (rp *ResolvingPath) Release(ctx context.Context) {
rp.root = VirtualDentry{}
rp.decRefStartAndMount(ctx)
rp.mount = nil
@@ -240,25 +254,6 @@ func (rp *ResolvingPath) Advance() {
}
}
-// Restart resets the stream of path components represented by rp to its state
-// on entry to the current FilesystemImpl method.
-func (rp *ResolvingPath) Restart(ctx context.Context) {
- rp.pit = rp.origParts[rp.numOrigParts-1]
- rp.mustBeDir = rp.mustBeDirOrig
- rp.symlinks = rp.symlinksOrig
- rp.curPart = rp.numOrigParts - 1
- copy(rp.parts[:], rp.origParts[:rp.numOrigParts])
- rp.releaseErrorState(ctx)
-}
-
-func (rp *ResolvingPath) relpathCommit() {
- rp.mustBeDirOrig = rp.mustBeDir
- rp.symlinksOrig = rp.symlinks
- rp.numOrigParts = rp.curPart + 1
- copy(rp.origParts[:rp.curPart], rp.parts[:])
- rp.origParts[rp.curPart] = rp.pit
-}
-
// CheckRoot is called before resolving the parent of the Dentry d. If the
// Dentry is contextually a VFS root, such that path resolution should treat
// d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the
@@ -405,11 +400,10 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef
rp.nextMount = nil
rp.nextStart = nil
- // Commit the previous FileystemImpl's progress through the relative
- // path. (Don't consume the path component that caused us to traverse
+ // Don't consume the path component that caused us to traverse
// through the mount root - i.e. the ".." - because we still need to
- // resolve the mount point's parent in the new FilesystemImpl.)
- rp.relpathCommit()
+ // resolve the mount point's parent in the new FilesystemImpl.
+ //
// Restart path resolution on the new Mount. Don't bother calling
// rp.releaseErrorState() since we already set nextMount and nextStart
// to nil above.
@@ -425,9 +419,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.nextMount = nil
// Consume the path component that represented the mount point.
rp.Advance()
- // Commit the previous FilesystemImpl's progress through the relative
- // path.
- rp.relpathCommit()
// Restart path resolution on the new Mount.
rp.releaseErrorState(ctx)
return true
@@ -442,9 +433,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.Advance()
// Prepend the symlink target to the relative path.
rp.relpathPrepend(rp.absSymlinkTarget)
- // Commit the previous FilesystemImpl's progress through the relative
- // path, including the symlink target we just prepended.
- rp.relpathCommit()
// Restart path resolution on the new Mount.
rp.releaseErrorState(ctx)
return true
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 00f1847d8..87fdcf403 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede
dentry: d,
}
rp.mount.IncRef()
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return vd, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return VirtualDentry{}, err
}
}
@@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
rp.mount.IncRef()
name := rp.Component()
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return parentVD, name, nil
}
if checkInvariants {
@@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return VirtualDentry{}, "", err
}
}
@@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
for {
err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldVD.DecRef(ctx)
return nil
}
@@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldVD.DecRef(ctx)
return err
}
@@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -425,7 +425,6 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
rp := vfs.getResolvingPath(creds, pop)
if opts.Flags&linux.O_DIRECTORY != 0 {
rp.mustBeDir = true
- rp.mustBeDirOrig = true
}
// Ignore O_PATH for verity, as verity performs extra operations on the fd for verification.
// The underlying filesystem that verity wraps opens the fd with O_PATH.
@@ -444,7 +443,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
for {
fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
if opts.FileExec {
if fd.Mount().Flags.NoExec {
@@ -468,7 +467,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
return fd, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, err
}
}
@@ -480,11 +479,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden
for {
target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return target, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return "", err
}
}
@@ -533,7 +532,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldParentVD.DecRef(ctx)
return nil
}
@@ -543,7 +542,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
oldParentVD.DecRef(ctx)
return err
}
@@ -569,7 +568,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
for {
err := rp.mount.fs.impl.RmdirAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -578,7 +577,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -590,11 +589,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent
for {
err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -606,11 +605,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential
for {
stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return stat, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return linux.Statx{}, err
}
}
@@ -623,11 +622,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti
for {
statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return statfs, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return linux.Statfs{}, err
}
}
@@ -652,7 +651,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
for {
err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -661,7 +660,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -686,7 +685,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
for {
err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if checkInvariants {
@@ -695,7 +694,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -707,7 +706,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
for {
bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return bep, nil
}
if checkInvariants {
@@ -716,7 +715,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
}
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, err
}
}
@@ -729,7 +728,7 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
for {
names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return names, nil
}
if err == syserror.ENOTSUP {
@@ -737,11 +736,11 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
// fs/xattr.c:vfs_listxattr() falls back to allowing the security
// subsystem to return security extended attributes, which by
// default don't exist.
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil, err
}
}
@@ -754,11 +753,11 @@ func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Creden
for {
val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return val, nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return "", err
}
}
@@ -771,11 +770,11 @@ func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Creden
for {
err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
@@ -787,11 +786,11 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre
for {
err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
if err == nil {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return nil
}
if !rp.handleError(ctx, err) {
- vfs.putResolvingPath(ctx, rp)
+ rp.Release(ctx)
return err
}
}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index aa30cfc85..e96ba50ae 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -22,8 +22,9 @@ go_library(
"errors.go",
"sock_err_list.go",
"socketops.go",
+ "stdclock.go",
+ "stdclock_state.go",
"tcpip.go",
- "time_unsafe.go",
"timer.go",
],
visibility = ["//visibility:public"],
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f75ee34ab..ef9126deb 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -123,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) {
q.notify = notify
}
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -130,6 +133,7 @@ type Endpoint struct {
mtu uint32
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
+ SupportedGSOKind stack.SupportedGSO
// Outbound packet queue.
q *queue
@@ -211,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.LinkEPCapabilities
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (*Endpoint) GSOMaxSize() uint32 {
return 1 << 15
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
+ return e.SupportedGSOKind
+}
+
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*Endpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index f042df82e..d971194e6 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/binary",
"//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index feb79fe0e..bddb1d0a2 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -45,7 +45,6 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string {
}
}
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -134,6 +136,9 @@ type endpoint struct {
// wg keeps track of running goroutines.
wg sync.WaitGroup
+
+ // gsoKind is the supported kind of GSO.
+ gsoKind stack.SupportedGSO
}
// Options specify the details about the fd-based endpoint to be created.
@@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
if isSocket {
if opts.GSOMaxSize != 0 {
if opts.SoftwareGSOEnabled {
- e.caps |= stack.CapabilitySoftwareGSO
+ e.gsoKind = stack.SWGSOSupported
} else {
- e.caps |= stack.CapabilityHardwareGSO
+ e.gsoKind = stack.HWGSOSupported
}
e.gsoMaxSize = opts.GSOMaxSize
}
@@ -403,6 +408,35 @@ type virtioNetHdr struct {
csumOffset uint16
}
+// marshal serializes h to a newly-allocated byte slice, in little-endian byte
+// order.
+//
+// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used
+// for general serialization. This makes it difficult to use go-marshal for
+// virtio types, as go-marshal implicitly uses the native byte ordering.
+func (h *virtioNetHdr) marshal() []byte {
+ buf := [virtioNetHdrSize]byte{
+ 0: byte(h.flags),
+ 1: byte(h.gsoType),
+
+ // Manually lay out the fields in little-endian byte order. Little endian =>
+ // least significant bit goes to the lower address.
+
+ 2: byte(h.hdrLen),
+ 3: byte(h.hdrLen >> 8),
+
+ 4: byte(h.gsoSize),
+ 5: byte(h.gsoSize >> 8),
+
+ 6: byte(h.csumStart),
+ 7: byte(h.csumStart >> 8),
+
+ 8: byte(h.csumOffset),
+ 9: byte(h.csumOffset >> 8),
+ }
+ return buf[:]
+}
+
// These constants are declared in linux/virtio_net.h.
const (
_VIRTIO_NET_HDR_F_NEEDS_CSUM = 1
@@ -441,7 +475,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
var builder iovec.Builder
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
- if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
vnetHdr.hdrLen = uint16(pkt.HeaderSize())
@@ -463,7 +497,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
}
}
- vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ vnetHdrBuf := vnetHdr.marshal()
builder.Add(vnetHdrBuf)
}
@@ -482,7 +516,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
}
var vnetHdrBuf []byte
- if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
vnetHdr.hdrLen = uint16(pkt.HeaderSize())
@@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
- vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ vnetHdrBuf = vnetHdr.marshal()
}
var builder iovec.Builder
@@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
}
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// SupportsHWGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+ return e.gsoKind
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
if e.hdrSize > 0 {
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index a7adf822b..4b7ef3aac 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -128,7 +128,7 @@ type readVDispatcher struct {
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &readVDispatcher{fd: fd, e: e}
- skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
return d, nil
}
@@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv),
}
- skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
for i := range d.bufs {
d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr)
}
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 89df35822..3e816b0c7 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 {
return 0
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *Endpoint) SupportedGSO() stack.SupportedGSO {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.SupportedGSO()
+ }
+ return stack.GSONotSupported
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index bba6a6973..b1a28491d 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -25,6 +25,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+
// endpoint represents a LinkEndpoint which implements a FIFO queue for all
// outgoing packets. endpoint can have 1 or more underlying queueDispatchers.
// All outgoing packets are consistenly hashed to a single underlying queue
@@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.lower.LinkAddress()
}
-// GSOMaxSize returns the maximum GSO packet size.
+// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
if gso, ok := e.lower.(stack.GSOEndpoint); ok {
return gso.GSOMaxSize()
@@ -149,6 +152,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.SupportedGSO()
+ }
+ return stack.GSONotSupported
+}
+
// WritePacket implements stack.LinkEndpoint.WritePacket.
//
// The packet must have the following fields populated:
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index ac35d81e7..d22974b12 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ip holds IPv4/IPv6 common utilities.
package ip
import (
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index f663fdc0b..3f2093f00 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
return
}
- // Skip the ip header, then deliver the error.
- pkt.Data().TrimFront(hlen)
+ // Keep needed information before trimming header.
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt)
+ dstAddr := hdr.DestinationAddress()
+ // Skip the ip header, then deliver the error.
+ pkt.Data().DeleteFront(hlen)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
@@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4DstUnreachable:
received.dstUnreachable.Increment()
- pkt.Data().TrimFront(header.ICMPv4MinimumSize)
- switch h.Code() {
+ mtu := h.MTU()
+ code := h.Code()
+ pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
case header.ICMPv4PortUnreachable:
e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt)
case header.ICMPv4FragmentationNeeded:
- networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize)
if err != nil {
networkMTU = 0
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 1319db32b..28bb61a08 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
return
}
+ // Keep needed information before trimming header.
+ p := hdr.TransportProtocol()
+ dstAddr := hdr.DestinationAddress()
+
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().TrimFront(header.IPv6MinimumSize)
- p := hdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6MinimumSize)
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// because they don't have the transport headers.
return
}
+ p = fragHdr.TransportProtocol()
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().TrimFront(header.IPv6FragmentHeaderSize)
- p = fragHdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
}
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
@@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize)
networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
}
+ pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
@@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch header.ICMPv6(hdr).Code() {
+ code := header.ICMPv6(hdr).Code()
+ pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 2d74e0abc..7d3725681 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: vv.ToView().ToVectorisedView(),
})
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets.
_ = r.WriteHeaderIncludedPacket(pkt)
}
@@ -264,6 +264,8 @@ type fwdTestPacketInfo struct {
Pkt *PacketBuffer
}
+var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil)
+
type fwdTestLinkEndpoint struct {
dispatcher NetworkDispatcher
mtu uint32
@@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
return caps | CapabilityResolutionRequired
}
-// GSOMaxSize returns the maximum GSO packet size.
-func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
- return 1 << 15
-}
-
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 48bb75e2f..9821a18d3 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
func BenchmarkCacheClear(b *testing.B) {
b.StopTimer()
config := DefaultNUDConfigurations()
- clock := &tcpip.StdClock{}
+ clock := tcpip.NewStdClock()
linkRes := newTestNeighborResolver(nil, config, clock)
linkRes.delay = 0
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 646979d1e..9527416cf 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -364,9 +364,10 @@ func (d PacketData) PullUp(size int) (buffer.View, bool) {
return d.pk.data.PullUp(size)
}
-// TrimFront removes count from the beginning of d. It panics if count >
-// d.Size().
-func (d PacketData) TrimFront(count int) {
+// DeleteFront removes count from the beginning of d. It panics if count >
+// d.Size(). All backing storage references after the front of the d are
+// invalidated.
+func (d PacketData) DeleteFront(count int) {
d.pk.data.TrimFront(count)
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 6728370c3..bd4eb4fed 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, test.data)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
- concatViews(test.link, test.network, test.transport, test.data))
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(test.link, test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(test.transport, test.data))
+ // Check the after state.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.link,
+ network: test.network,
+ transport: test.transport,
+ data: test.data,
+ })
})
}
}
@@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) {
if got, want := pk.Size(), len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- // After state of pk.
- var (
- link = test.data[:test.link]
- network = test.data[test.link:][:test.network]
- transport = test.data[test.link+test.network:][:test.transport]
- payload = test.data[allHdrSize:]
- )
- checkData(t, pk, payload)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(link, network, transport, payload))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(network, transport, payload))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(transport, payload))
+ // Check the after state of pk.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.data[:test.link],
+ network: test.data[test.link:][:test.network],
+ transport: test.data[test.link+test.network:][:test.transport],
+ data: test.data[allHdrSize:],
+ })
})
}
}
@@ -252,6 +226,39 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
})
}
+// This is a very obscure use-case seen in the code that verifies packets
+// before sending them out. It tries to parse the headers to verify.
+// PacketHeader was initially not designed to mix Push() and Consume(), but it
+// works and it's been relied upon. Include a test here.
+func TestPacketHeaderPushConsumeMixed(t *testing.T) {
+ link := makeView(10)
+ network := makeView(20)
+ data := makeView(30)
+
+ initData := append([]byte(nil), network...)
+ initData = append(initData, data...)
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: len(link),
+ Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
+ })
+
+ // 1. Consume network header
+ gotNetwork, ok := pk.NetworkHeader().Consume(len(network))
+ if !ok {
+ t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network))
+ }
+ checkViewEqual(t, "gotNetwork", gotNetwork, network)
+
+ // 2. Push link header
+ copy(pk.LinkHeader().Push(len(link)), link)
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ network: network,
+ data: data,
+ })
+}
+
func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
const headerSize = 10
@@ -397,11 +404,11 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // TrimFront
+ // DeleteFront
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().TrimFront(n)
+ pkt.Data().DeleteFront(n)
checkData(t, pkt, []byte(tc.data)[n:])
})
@@ -494,6 +501,37 @@ func TestPacketBufferData(t *testing.T) {
}
}
+type packetContents struct {
+ link buffer.View
+ network buffer.View
+ transport buffer.View
+ data buffer.View
+}
+
+func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) {
+ t.Helper()
+ // Headers.
+ checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link)
+ checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network)
+ checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport)
+ // Data.
+ checkData(t, pk, want.data)
+ // Whole packet.
+ checkViewEqual(t, prefix+"pk.Views()",
+ concatViews(pk.Views()...),
+ concatViews(want.link, want.network, want.transport, want.data))
+ // PayloadSince.
+ checkViewEqual(t, prefix+"PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(want.link, want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(want.transport, want.data))
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -510,19 +548,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, data)
- checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
- // Check the initial values for each header.
- checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
- checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
- checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
- // Check the initial valies for PayloadSince.
- checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()), data)
+ checkPacketContents(t, "Initial ", pk, packetContents{
+ data: data,
+ })
}
func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 7ad206f6d..e26225552 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -756,11 +756,6 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityHardwareGSO
-
- // CapabilitySoftwareGSO indicates the link endpoint supports of sending
- // multiple packets using a single call (LinkEndpoint.WritePackets).
- CapabilitySoftwareGSO
)
// NetworkLinkEndpoint is a data-link layer that supports sending network
@@ -1047,10 +1042,29 @@ type GSO struct {
MaxSize uint32
}
+// SupportedGSO returns the type of segmentation offloading supported.
+type SupportedGSO int
+
+const (
+ // GSONotSupported indicates that segmentation offloading is not supported.
+ GSONotSupported SupportedGSO = iota
+
+ // HWGSOSupported indicates that segmentation offloading may be performed by
+ // the hardware.
+ HWGSOSupported
+
+ // SWGSOSupported indicates that segmentation offloading may be performed in
+ // software.
+ SWGSOSupported
+)
+
// GSOEndpoint provides access to GSO properties.
type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
+
+ // SupportedGSO returns the supported segmentation offloading.
+ SupportedGSO() SupportedGSO
}
// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 4ecde5995..8a044c073 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool {
// HasSoftwareGSOCapability returns true if the route supports software GSO.
func (r *Route) HasSoftwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == SWGSOSupported
+ }
+ return false
}
// HasHardwareGSOCapability returns true if the route supports hardware GSO.
func (r *Route) HasHardwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == HWGSOSupported
+ }
+ return false
}
// HasSaveRestoreCapability returns true if the route supports save/restore.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 843118b13..436392f23 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -322,7 +322,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
- clock = &tcpip.StdClock{}
+ clock = tcpip.NewStdClock()
}
if opts.UniqueID == nil {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 8ead3b8df..4fe9df999 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -138,11 +138,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data().TrimFront(fakeNetHeaderLen)
+ // DeleteFront invalidates slices. Make a copy before trimming.
+ nb := append([]byte(nil), hdr...)
+ pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go
new file mode 100644
index 000000000..7ce43a68e
--- /dev/null
+++ b/pkg/tcpip/stdclock.go
@@ -0,0 +1,130 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// stdClock implements Clock with the time package.
+//
+// +stateify savable
+type stdClock struct {
+ // baseTime holds the time when the clock was constructed.
+ //
+ // This value is used to calculate the monotonic time from the time package.
+ // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks,
+ //
+ // Operating systems provide both a “wall clock,” which is subject to
+ // changes for clock synchronization, and a “monotonic clock,” which is not.
+ // The general rule is that the wall clock is for telling time and the
+ // monotonic clock is for measuring time. Rather than split the API, in this
+ // package the Time returned by time.Now contains both a wall clock reading
+ // and a monotonic clock reading; later time-telling operations use the wall
+ // clock reading, but later time-measuring operations, specifically
+ // comparisons and subtractions, use the monotonic clock reading.
+ //
+ // ...
+ //
+ // If Times t and u both contain monotonic clock readings, the operations
+ // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using
+ // the monotonic clock readings alone, ignoring the wall clock readings. If
+ // either t or u contains no monotonic clock reading, these operations fall
+ // back to using the wall clock readings.
+ //
+ // Given the above, we can safely conclude that time.Since(baseTime) will
+ // return monotonically increasing values if we use time.Now() to set baseTime
+ // at the time of clock construction.
+ //
+ // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per
+ // https://golang.org/pkg/time/#Since.
+ baseTime time.Time `state:"nosave"`
+
+ // monotonicOffset is the offset applied to the calculated monotonic time.
+ //
+ // monotonicOffset is assigned maxMonotonic after restore so that the
+ // monotonic time will continue from where it "left off" before saving as part
+ // of S/R.
+ monotonicOffset int64 `state:"nosave"`
+
+ // monotonicMU protects maxMonotonic.
+ monotonicMU sync.Mutex `state:"nosave"`
+ maxMonotonic int64
+}
+
+// NewStdClock returns an instance of a clock that uses the time package.
+func NewStdClock() Clock {
+ return &stdClock{
+ baseTime: time.Now(),
+ }
+}
+
+var _ Clock = (*stdClock)(nil)
+
+// NowNanoseconds implements Clock.NowNanoseconds.
+func (*stdClock) NowNanoseconds() int64 {
+ return time.Now().UnixNano()
+}
+
+// NowMonotonic implements Clock.NowMonotonic.
+func (s *stdClock) NowMonotonic() int64 {
+ sinceBase := time.Since(s.baseTime)
+ if sinceBase < 0 {
+ panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime))
+ }
+
+ monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset
+
+ s.monotonicMU.Lock()
+ defer s.monotonicMU.Unlock()
+
+ // Monotonic time values must never decrease.
+ if monotonicValue > s.maxMonotonic {
+ s.maxMonotonic = monotonicValue
+ }
+
+ return s.maxMonotonic
+}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*stdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/stdclock_state.go b/pkg/tcpip/stdclock_state.go
new file mode 100644
index 000000000..795db9181
--- /dev/null
+++ b/pkg/tcpip/stdclock_state.go
@@ -0,0 +1,26 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import "time"
+
+// afterLoad is invoked by stateify.
+func (s *stdClock) afterLoad() {
+ s.baseTime = time.Now()
+
+ s.monotonicMU.Lock()
+ defer s.monotonicMU.Unlock()
+ s.monotonicOffset = s.maxMonotonic
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0ba71b62e..d8a10065d 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -73,7 +73,7 @@ type Clock interface {
// nanoseconds since the Unix epoch.
NowNanoseconds() int64
- // NowMonotonic returns a monotonic time value.
+ // NowMonotonic returns a monotonic time value at nanosecond resolution.
NowMonotonic() int64
// AfterFunc waits for the duration to elapse and then calls f in its own
@@ -1107,6 +1107,7 @@ const (
// LingerOption is used by SetSockOpt/GetSockOpt to set/get the
// duration for which a socket lingers before returning from Close.
//
+// +marshal
// +stateify savable
type LingerOption struct {
Enabled bool
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
deleted file mode 100644
index eeea97b12..000000000
--- a/pkg/tcpip/time_unsafe.go
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build go1.9
-// +build !go1.18
-
-// Check go:linkname function signatures when updating Go version.
-
-package tcpip
-
-import (
- "time" // Used with go:linkname.
- _ "unsafe" // Required for go:linkname.
-)
-
-// StdClock implements Clock with the time package.
-//
-// +stateify savable
-type StdClock struct{}
-
-var _ Clock = (*StdClock)(nil)
-
-//go:linkname now time.now
-func now() (sec int64, nsec int32, mono int64)
-
-// NowNanoseconds implements Clock.NowNanoseconds.
-func (*StdClock) NowNanoseconds() int64 {
- sec, nsec, _ := now()
- return sec*1e9 + int64(nsec)
-}
-
-// NowMonotonic implements Clock.NowMonotonic.
-func (*StdClock) NowMonotonic() int64 {
- _, _, mono := now()
- return mono
-}
-
-// AfterFunc implements Clock.AfterFunc.
-func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
- return &stdTimer{
- t: time.AfterFunc(d, f),
- }
-}
-
-type stdTimer struct {
- t *time.Timer
-}
-
-var _ Timer = (*stdTimer)(nil)
-
-// Stop implements Timer.Stop.
-func (st *stdTimer) Stop() bool {
- return st.t.Stop()
-}
-
-// Reset implements Timer.Reset.
-func (st *stdTimer) Reset(d time.Duration) {
- st.t.Reset(d)
-}
-
-// NewStdTimer returns a Timer implemented with the time package.
-func NewStdTimer(t *time.Timer) Timer {
- return &stdTimer{t: t}
-}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index a82384c49..1633d0aeb 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -29,7 +29,7 @@ const (
)
func TestJobReschedule(t *testing.T) {
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,7 +43,7 @@ func TestJobReschedule(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- job := tcpip.NewJob(&clock, &lock, func() {
+ job := tcpip.NewJob(clock, &lock, func() {
wg.Done()
})
job.Schedule(shortDuration)
@@ -56,11 +56,11 @@ func TestJobReschedule(t *testing.T) {
func TestJobExecution(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
- job := tcpip.NewJob(&clock, &lock, func() {
+ job := tcpip.NewJob(clock, &lock, func() {
ch <- struct{}{}
})
job.Schedule(shortDuration)
@@ -83,11 +83,11 @@ func TestJobExecution(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(middleDuration)
lock.Lock()
@@ -114,12 +114,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -151,13 +151,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) {
func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -174,12 +174,12 @@ func TestJobImmediatelyCancel(t *testing.T) {
func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
job.Cancel()
lock.Unlock()
@@ -206,12 +206,12 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
@@ -239,12 +239,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- var clock tcpip.StdClock
+ clock := tcpip.NewStdClock()
var lock sync.Mutex
ch := make(chan struct{})
lock.Lock()
- job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
job.Cancel()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 9948f305b..8afde7fca 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -747,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(b/129292233): Determine if len(h) check is still needed after early
- // parsing.
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
+ // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -756,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(b/129292233): Determine if len(h) check is still needed after early
- // parsing.
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
+ // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 16f8c5212..53efecc5a 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -1214,9 +1214,9 @@ func (c *Context) SACKEnabled() bool {
// SetGSOEnabled enables or disables generic segmentation offload.
func (c *Context) SetGSOEnabled(enable bool) {
if enable {
- c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ c.linkEP.SupportedGSOKind = stack.HWGSOSupported
} else {
- c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ c.linkEP.SupportedGSOKind = stack.GSONotSupported
}
}