Diffstat (limited to 'pkg')
-rw-r--r-- pkg/abi/linux/file.go | 4
-rw-r--r-- pkg/abi/linux/socket.go | 13
-rw-r--r-- pkg/crypto/crypto_stdlib.go | 15
-rw-r--r-- pkg/lisafs/BUILD | 117
-rw-r--r-- pkg/lisafs/README.md | 3
-rw-r--r-- pkg/lisafs/channel.go | 190
-rw-r--r-- pkg/lisafs/client.go | 432
-rw-r--r-- pkg/lisafs/client_file.go | 475
-rw-r--r-- pkg/lisafs/communicator.go | 80
-rw-r--r-- pkg/lisafs/connection.go | 320
-rw-r--r-- pkg/lisafs/connection_test.go | 194
-rw-r--r-- pkg/lisafs/fd.go | 374
-rw-r--r-- pkg/lisafs/handlers.go | 768
-rw-r--r-- pkg/lisafs/lisafs.go | 18
-rw-r--r-- pkg/lisafs/message.go | 1251
-rw-r--r-- pkg/lisafs/sample_message.go | 110
-rw-r--r-- pkg/lisafs/server.go | 113
-rw-r--r-- pkg/lisafs/sock.go | 208
-rw-r--r-- pkg/lisafs/sock_test.go | 217
-rw-r--r-- pkg/lisafs/testsuite/BUILD | 20
-rw-r--r-- pkg/lisafs/testsuite/testsuite.go | 637
-rw-r--r-- pkg/p9/client.go | 2
-rw-r--r-- pkg/sentry/fsimpl/gofer/BUILD | 3
-rw-r--r-- pkg/sentry/fsimpl/gofer/directory.go | 101
-rw-r--r-- pkg/sentry/fsimpl/gofer/filesystem.go | 490
-rw-r--r-- pkg/sentry/fsimpl/gofer/gofer.go | 681
-rw-r--r-- pkg/sentry/fsimpl/gofer/gofer_test.go | 1
-rw-r--r-- pkg/sentry/fsimpl/gofer/handle.go | 80
-rw-r--r-- pkg/sentry/fsimpl/gofer/p9file.go | 8
-rw-r--r-- pkg/sentry/fsimpl/gofer/regular_file.go | 26
-rw-r--r-- pkg/sentry/fsimpl/gofer/revalidate.go | 50
-rw-r--r-- pkg/sentry/fsimpl/gofer/save_restore.go | 143
-rw-r--r-- pkg/sentry/fsimpl/gofer/socket.go | 45
-rw-r--r-- pkg/sentry/fsimpl/gofer/special_file.go | 18
-rw-r--r-- pkg/sentry/fsimpl/gofer/symlink.go | 8
-rw-r--r-- pkg/sentry/fsimpl/gofer/time.go | 5
-rw-r--r-- pkg/sentry/seccheck/BUILD | 4
-rw-r--r-- pkg/sentry/seccheck/execve.go | 65
-rw-r--r-- pkg/sentry/seccheck/exit.go | 57
-rw-r--r-- pkg/sentry/seccheck/seccheck.go | 26
-rw-r--r-- pkg/sentry/socket/control/control.go | 19
-rw-r--r-- pkg/sentry/socket/netstack/netstack.go | 20
-rw-r--r-- pkg/sentry/socket/socket.go | 26
-rw-r--r-- pkg/sentry/socket/unix/transport/queue.go | 8
-rw-r--r-- pkg/sentry/time/sampler_arm64.go | 4
-rw-r--r-- pkg/sentry/vfs/resolving_path.go | 6
-rw-r--r-- pkg/tcpip/checker/checker.go | 13
-rw-r--r-- pkg/tcpip/network/ipv4/icmp.go | 89
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4.go | 55
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4_test.go | 136
-rw-r--r-- pkg/tcpip/network/ipv6/BUILD | 1
-rw-r--r-- pkg/tcpip/network/ipv6/icmp.go | 77
-rw-r--r-- pkg/tcpip/network/ipv6/icmp_test.go | 3
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6.go | 55
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6_test.go | 151
-rw-r--r-- pkg/tcpip/socketops.go | 18
-rw-r--r-- pkg/tcpip/stack/conntrack.go | 110
-rw-r--r-- pkg/tcpip/stack/iptables.go | 4
-rw-r--r-- pkg/tcpip/stack/iptables_targets.go | 46
-rw-r--r-- pkg/tcpip/stack/packet_buffer.go | 40
-rw-r--r-- pkg/tcpip/tcpip.go | 14
-rw-r--r-- pkg/tcpip/tests/integration/BUILD | 4
-rw-r--r-- pkg/tcpip/tests/integration/iptables_test.go | 288
-rw-r--r-- pkg/tcpip/tests/utils/utils.go | 36
-rw-r--r-- pkg/tcpip/transport/packet/endpoint.go | 50
-rw-r--r-- pkg/tcpip/transport/packet/endpoint_state.go | 6
-rw-r--r-- pkg/tcpip/transport/tcp/BUILD | 13
-rw-r--r-- pkg/tcpip/transport/tcp/accept.go | 106
-rw-r--r-- pkg/tcpip/transport/tcp/connect.go | 9
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go | 24
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint_state.go | 2
-rw-r--r-- pkg/tcpip/transport/tcp/rcv_test.go | 2
-rw-r--r-- pkg/tcpip/transport/tcp/segment_test.go | 2
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_test.go | 75
-rw-r--r-- pkg/tcpip/transport/udp/endpoint.go | 18
-rw-r--r-- pkg/tcpip/transport/udp/udp_test.go | 99
76 files changed, 8317 insertions, 684 deletions
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 1e23850a9..67646f837 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -242,7 +242,7 @@ const (
// Statx represents struct statx.
//
-// +marshal
+// +marshal slice:StatxSlice
type Statx struct {
Mask uint32
Blksize uint32
@@ -270,6 +270,8 @@ type Statx struct {
var SizeOfStatx = (*Statx)(nil).SizeBytes()
// FileMode represents a mode_t.
+//
+// +marshal
type FileMode uint16
// Permissions returns just the permission bits.
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 95871b8a5..f60e42997 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -542,6 +542,15 @@ type ControlMessageIPPacketInfo struct {
DestinationAddr InetAddr
}
+// ControlMessageIPv6PacketInfo represents struct in6_pktinfo from linux/ipv6.h.
+//
+// +marshal
+// +stateify savable
+type ControlMessageIPv6PacketInfo struct {
+ Addr Inet6Addr
+ NIC uint32
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes()
@@ -566,6 +575,10 @@ const SizeOfControlMessageTClass = 4
// control message.
const SizeOfControlMessageIPPacketInfo = 12
+// SizeOfControlMessageIPv6PacketInfo is the size of a
+// ControlMessageIPv6PacketInfo.
+const SizeOfControlMessageIPv6PacketInfo = 20
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
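The new 20-byte size constant follows directly from the layout of struct in6_pktinfo: a 16-byte in6_addr followed by a 4-byte interface index. The mirror struct below is a hypothetical stand-alone sketch (not part of the diff) that just double-checks that arithmetic:

package main

import (
	"fmt"
	"unsafe"
)

// in6PktInfo mirrors struct in6_pktinfo from linux/ipv6.h, i.e. the layout
// behind ControlMessageIPv6PacketInfo above. The name is hypothetical.
type in6PktInfo struct {
	Addr [16]byte // struct in6_addr
	NIC  uint32   // ipi6_ifindex
}

func main() {
	// 16 + 4 = 20, matching SizeOfControlMessageIPv6PacketInfo.
	fmt.Println(unsafe.Sizeof(in6PktInfo{})) // 20
}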
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
index 69e867386..28eba2ff6 100644
--- a/pkg/crypto/crypto_stdlib.go
+++ b/pkg/crypto/crypto_stdlib.go
@@ -19,14 +19,21 @@ package crypto
import (
"crypto/ecdsa"
+ "crypto/elliptic"
"crypto/sha512"
+ "fmt"
"math/big"
)
-// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
-// public key, pub. Its return value records whether the signature is valid.
-func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) {
- return ecdsa.Verify(pub, hash, r, s), nil
+// EcdsaP384Sha384Verify verifies the signature in r, s over data using ECDSA
+// P-384 + SHA-384 and the public key, pub. Its return value records whether
+// the signature is valid.
+func EcdsaP384Sha384Verify(pub *ecdsa.PublicKey, data []byte, r, s *big.Int) (bool, error) {
+ if pub.Curve != elliptic.P384() {
+ return false, fmt.Errorf("unsupported key curve: want P-384, got %v", pub.Curve)
+ }
+ digest := sha512.Sum384(data)
+ return ecdsa.Verify(pub, digest[:], r, s), nil
}
// SumSha384 returns the SHA384 checksum of the data.
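Since EcdsaP384Sha384Verify now hashes the data itself, a caller signs the SHA-384 digest and hands the raw data to the verifier. A minimal usage sketch, assuming the package is importable as gvisor.dev/gvisor/pkg/crypto:

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha512"
	"fmt"

	gvcrypto "gvisor.dev/gvisor/pkg/crypto"
)

func main() {
	// Sign the SHA-384 digest of the data with a P-384 key...
	priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	if err != nil {
		panic(err)
	}
	data := []byte("hello")
	digest := sha512.Sum384(data)
	r, s, err := ecdsa.Sign(rand.Reader, priv, digest[:])
	if err != nil {
		panic(err)
	}

	// ...and verify against the raw data; the helper recomputes the digest
	// and rejects any key that is not on the P-384 curve.
	ok, err := gvcrypto.EcdsaP384Sha384Verify(&priv.PublicKey, data, r, s)
	fmt.Println(ok, err) // true <nil>
}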
diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD
new file mode 100644
index 000000000..313c1756d
--- /dev/null
+++ b/pkg/lisafs/BUILD
@@ -0,0 +1,117 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "control_fd_refs",
+ out = "control_fd_refs.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_refs",
+ out = "open_fd_refs.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "OpenFD",
+ },
+)
+
+go_template_instance(
+ name = "control_fd_list",
+ out = "control_fd_list.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*ControlFD",
+ "Linker": "*ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_list",
+ out = "open_fd_list.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*OpenFD",
+ "Linker": "*OpenFD",
+ },
+)
+
+go_library(
+ name = "lisafs",
+ srcs = [
+ "channel.go",
+ "client.go",
+ "client_file.go",
+ "communicator.go",
+ "connection.go",
+ "control_fd_list.go",
+ "control_fd_refs.go",
+ "fd.go",
+ "handlers.go",
+ "lisafs.go",
+ "message.go",
+ "open_fd_list.go",
+ "open_fd_refs.go",
+ "sample_message.go",
+ "server.go",
+ "sock.go",
+ ],
+ marshal = True,
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
+ "//pkg/context",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
+ "//pkg/fspath",
+ "//pkg/hostarch",
+ "//pkg/log",
+ "//pkg/marshal/primitive",
+ "//pkg/p9",
+ "//pkg/refsvfs2",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "sock_test",
+ size = "small",
+ srcs = ["sock_test.go"],
+ library = ":lisafs",
+ deps = [
+ "//pkg/marshal",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "connection_test",
+ size = "small",
+ srcs = ["connection_test.go"],
+ deps = [
+ ":lisafs",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md
index 51d0d40e5..6b857321a 100644
--- a/pkg/lisafs/README.md
+++ b/pkg/lisafs/README.md
@@ -1,5 +1,8 @@
# Replacing 9P
+NOTE: LISAFS is **NOT** production ready. There are still some security concerns
+that must be resolved first.
+
## Background
The Linux filesystem model consists of the following key aspects (modulo mounts,
diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go
new file mode 100644
index 000000000..301212e51
--- /dev/null
+++ b/pkg/lisafs/channel.go
@@ -0,0 +1,190 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "runtime"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes())
+)
+
+// maxChannels returns the number of channels a client can create.
+//
+// The server will reject channel creation requests beyond this (per client).
+// Note that we don't want the number of channels to be too large, because each
+// accounts for a large region of shared memory.
+// TODO(gvisor.dev/issue/6313): Tune the number of channels.
+func maxChannels() int {
+ maxChans := runtime.GOMAXPROCS(0)
+ if maxChans < 2 {
+ maxChans = 2
+ }
+ if maxChans > 4 {
+ maxChans = 4
+ }
+ return maxChans
+}
+
+// channel implements Communicator and represents the communication endpoint
+// for the client and the server; it is used to perform fast IPC. Apart from
+// communicating data, a channel is also capable of donating file descriptors.
+type channel struct {
+ fdTracker
+ dead bool
+ data flipcall.Endpoint
+ fdChan fdchannel.Endpoint
+}
+
+var _ Communicator = (*channel)(nil)
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (ch *channel) PayloadBuf(size uint32) []byte {
+ return ch.data.Data()[chanHeaderLen : chanHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ // Write header. Requests can not donate FDs.
+ ch.marshalHdr(m, 0 /* numFDs */)
+
+ // One-shot communication. RPCs are expected to be quick rather than block.
+ rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen)
+ if err != nil {
+ // This channel is now unusable.
+ ch.dead = true
+ // Map the transport errors to EIO, but also log the real error.
+ log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err)
+ return 0, 0, unix.EIO
+ }
+
+ return ch.rcvMsg(rcvDataLen)
+}
+
+func (ch *channel) shutdown() {
+ ch.data.Shutdown()
+}
+
+func (ch *channel) destroy() {
+ ch.dead = true
+ ch.fdChan.Destroy()
+ ch.data.Destroy()
+}
+
+// createChannel creates a server side channel. It returns a packet window
+// descriptor (for the data channel) and an open socket for the FD channel.
+func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ // If c.channels is nil, the connection has closed.
+ if c.channels == nil || len(c.channels) >= maxChannels() {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS
+ }
+ ch := &channel{}
+
+ // Set up data channel.
+ desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize))
+ if err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ if err := ch.data.Init(flipcall.ServerSide, desc); err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+
+ // Set up FD channel.
+ fdSocks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ ch.data.Destroy()
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ ch.fdChan.Init(fdSocks[0])
+ clientFDSock := fdSocks[1]
+
+ c.channels = append(c.channels, ch)
+ return ch, desc, clientFDSock, nil
+}
+
+// sendFDs sends as many FDs as it can. Failure to send an FD does not cause
+// an error or fail the entire RPC; FDs are considered supplementary responses
+// that are not critical to the RPC response itself. However, if sending the
+// i-th FD fails, none of the following FDs are sent either, because the order
+// in which FDs are donated is significant.
+func (ch *channel) sendFDs(fds []int) uint8 {
+ numFDs := len(fds)
+ if numFDs == 0 {
+ return 0
+ }
+
+ if numFDs > math.MaxUint8 {
+ log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs)
+ return 0
+ }
+
+ for i, fd := range fds {
+ if err := ch.fdChan.SendFD(fd); err != nil {
+ log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err)
+ return uint8(i)
+ }
+ }
+ return uint8(numFDs)
+}
+
+// channelHeader is the header preceding each message received on the
+// flipcall endpoint when protocol version 1 is in use.
+//
+// +marshal
+type channelHeader struct {
+ message MID
+ numFDs uint8
+ _ uint8 // Need to make struct packed.
+}
+
+func (ch *channel) marshalHdr(m MID, numFDs uint8) {
+ header := &channelHeader{
+ message: m,
+ numFDs: numFDs,
+ }
+ header.MarshalUnsafe(ch.data.Data())
+}
+
+func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) {
+ if dataLen < chanHeaderLen {
+ log.Warningf("received data has size smaller than header length: %d", dataLen)
+ return 0, 0, unix.EIO
+ }
+
+ // Read header first.
+ var header channelHeader
+ header.UnmarshalUnsafe(ch.data.Data())
+
+ // Read any FDs.
+ for i := 0; i < int(header.numFDs); i++ {
+ fd, err := ch.fdChan.RecvFDNonblock()
+ if err != nil {
+ log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err)
+ break
+ }
+ ch.TrackFD(fd)
+ }
+
+ return header.message, dataLen - chanHeaderLen, nil
+}
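For orientation, here is a hypothetical back-of-the-envelope sketch (not part of the diff) of the wire format implied by channelHeader and PayloadBuf: every flipcall packet starts with the 4-byte header written by marshalHdr, and the payload follows it. This assumes MID is a uint16, as it is defined elsewhere in the lisafs package, and a little-endian host:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	const (
		sizeOfMID    = 2 // MID is a uint16
		sizeOfNumFDs = 1 // numFDs uint8
		sizeOfPad    = 1 // explicit padding byte keeping the struct packed
	)
	fmt.Println(sizeOfMID + sizeOfNumFDs + sizeOfPad) // chanHeaderLen == 4

	// Hand-rolled header for message ID 3 with one donated FD, i.e. what
	// marshalHdr would place at the front of the flipcall data window; the
	// payload returned by PayloadBuf starts right after these 4 bytes.
	var hdr [4]byte
	binary.LittleEndian.PutUint16(hdr[0:2], 3) // message
	hdr[2] = 1                                 // numFDs
	fmt.Println(hdr)                           // [3 0 1 0]
}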
diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go
new file mode 100644
index 000000000..ccf1b9f72
--- /dev/null
+++ b/pkg/lisafs/client.go
@@ -0,0 +1,432 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "math"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ // fdsToCloseBatchSize is the number of closed FDs batched before a Close
+ // RPC is made to close them all. fdsToCloseBatchSize is immutable.
+ fdsToCloseBatchSize = 100
+)
+
+// Client helps manage a connection to the lisafs server and pass messages
+// efficiently. There is a 1:1 mapping between a Connection and a Client.
+type Client struct {
+ // sockComm is the main socket by which this connection is established.
+ // Communication over the socket is synchronized by sockMu.
+ sockMu sync.Mutex
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels and availableChannels.
+ channelsMu sync.Mutex
+ // channels tracks all the channels.
+ channels []*channel
+ // availableChannels is a LIFO (stack) of channels available to be used.
+ availableChannels []*channel
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // watchdogWg only holds the watchdog goroutine.
+ watchdogWg sync.WaitGroup
+
+ // supported caches information about which messages are supported. It is
+ // indexed by MID. An MID is supported if supported[MID] is true.
+ supported []bool
+
+ // maxMessageSize is the maximum payload length (in bytes) that can be sent.
+ // It is initialized on Mount and is immutable.
+ maxMessageSize uint32
+
+ // fdsToClose tracks the FDs to close. It caches the FDs no longer being used
+ // by the client and closes them in one shot. It is not preserved across
+ // checkpoint/restore as FDIDs are not preserved.
+ fdsMu sync.Mutex
+ fdsToClose []FDID
+}
+
+// NewClient creates a new client for communication with the server. It mounts
+// the server and creates channels for fast IPC. NewClient takes ownership over
+// the passed socket. On success, it returns the initialized client along with
+// the root Inode.
+func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) {
+ maxChans := maxChannels()
+ c := &Client{
+ sockComm: newSockComm(sock),
+ channels: make([]*channel, 0, maxChans),
+ availableChannels: make([]*channel, 0, maxChans),
+ maxMessageSize: 1 << 20, // 1 MB for now.
+ fdsToClose: make([]FDID, 0, fdsToCloseBatchSize),
+ }
+
+ // Start a goroutine to check socket health. This goroutine is also
+ // responsible for client cleanup.
+ c.watchdogWg.Add(1)
+ go c.watchdog()
+
+ // Clean everything up if anything fails.
+ cu := cleanup.Make(func() {
+ c.Close()
+ })
+ defer cu.Clean()
+
+ // Mount the server first. Assume Mount is supported so that we can make the
+ // Mount RPC below.
+ c.supported = make([]bool, Mount+1)
+ c.supported[Mount] = true
+ mountMsg := MountReq{
+ MountPath: SizedString(mountPath),
+ }
+ var mountResp MountResp
+ if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil {
+ return nil, nil, err
+ }
+
+ // Initialize client.
+ c.maxMessageSize = uint32(mountResp.MaxMessageSize)
+ var maxSuppMID MID
+ for _, suppMID := range mountResp.SupportedMs {
+ if suppMID > maxSuppMID {
+ maxSuppMID = suppMID
+ }
+ }
+ c.supported = make([]bool, maxSuppMID+1)
+ for _, suppMID := range mountResp.SupportedMs {
+ c.supported[suppMID] = true
+ }
+
+ // Create channels in parallel so that channels can be used to create more
+ // channels and costly initialization like flipcall.Endpoint.Connect can
+ // proceed in parallel.
+ var channelsWg sync.WaitGroup
+ channelErrs := make([]error, maxChans)
+ for i := 0; i < maxChans; i++ {
+ channelsWg.Add(1)
+ curChanID := i
+ go func() {
+ defer channelsWg.Done()
+ ch, err := c.createChannel()
+ if err != nil {
+ log.Warningf("channel creation failed: %v", err)
+ channelErrs[curChanID] = err
+ return
+ }
+ c.channelsMu.Lock()
+ c.channels = append(c.channels, ch)
+ c.availableChannels = append(c.availableChannels, ch)
+ c.channelsMu.Unlock()
+ }()
+ }
+ channelsWg.Wait()
+
+ for _, channelErr := range channelErrs {
+ // Return the first non-nil channel creation error.
+ if channelErr != nil {
+ return nil, nil, channelErr
+ }
+ }
+ cu.Release()
+
+ return c, &mountResp.Root, nil
+}
+
+func (c *Client) watchdog() {
+ defer c.watchdogWg.Done()
+
+ events := []unix.PollFd{
+ {
+ Fd: int32(c.sockComm.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("lisafs.Client.watch(): %v", err)
+ } else if n != 1 {
+ log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Shutdown all active channels and wait for them to complete.
+ c.shutdownActiveChans()
+ c.activeWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ c.channelsMu.Unlock()
+
+ // Close main socket.
+ c.sockComm.destroy()
+}
+
+func (c *Client) shutdownActiveChans() {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ availableChans := make(map[*channel]bool)
+ for _, ch := range c.availableChannels {
+ availableChans[ch] = true
+ }
+ for _, ch := range c.channels {
+ // A channel that is not available is active.
+ if _, ok := availableChans[ch]; !ok {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.shutdown()
+ }
+ }
+
+ // Prevent channels from becoming available and serving new requests.
+ c.availableChannels = nil
+}
+
+// Close shuts down the main socket and waits for the watchdog to clean up.
+func (c *Client) Close() {
+ // This shutdown has no effect if the watchdog has already fired and closed
+ // the main socket.
+ c.sockComm.shutdown()
+ c.watchdogWg.Wait()
+}
+
+func (c *Client) createChannel() (*channel, error) {
+ var chanResp ChannelResp
+ var fds [2]int
+ if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil {
+ return nil, err
+ }
+ if fds[0] < 0 || fds[1] < 0 {
+ closeFDs(fds[:])
+ return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds)
+ }
+
+ // Let's create the channel.
+ defer closeFDs(fds[:1]) // The data FD is not needed after this.
+ desc := flipcall.PacketWindowDescriptor{
+ FD: fds[0],
+ Offset: chanResp.dataOffset,
+ Length: int(chanResp.dataLength),
+ }
+
+ ch := &channel{}
+ if err := ch.data.Init(flipcall.ClientSide, desc); err != nil {
+ closeFDs(fds[1:])
+ return nil, err
+ }
+ ch.fdChan.Init(fds[1]) // fdChan now owns this FD.
+
+ // Only a connected channel is usable.
+ if err := ch.data.Connect(); err != nil {
+ ch.destroy()
+ return nil, err
+ }
+ return ch, nil
+}
+
+// IsSupported returns true if this connection supports the passed message.
+func (c *Client) IsSupported(m MID) bool {
+ return int(m) < len(c.supported) && c.supported[m]
+}
+
+// CloseFDBatched either queues the passed FD to be closed or makes a batch
+// RPC to close all the accumulated FDs-to-close.
+func (c *Client) CloseFDBatched(ctx context.Context, fd FDID) {
+ c.fdsMu.Lock()
+ c.fdsToClose = append(c.fdsToClose, fd)
+ if len(c.fdsToClose) < fdsToCloseBatchSize {
+ c.fdsMu.Unlock()
+ return
+ }
+
+ // Flush the cache. We should not hold fdsMu while making an RPC, so be sure
+ // to copy the fdsToClose to another buffer before unlocking fdsMu.
+ var toCloseArr [fdsToCloseBatchSize]FDID
+ toClose := toCloseArr[:len(c.fdsToClose)]
+ copy(toClose, c.fdsToClose)
+
+ // Clear fdsToClose so other FDIDs can be appended.
+ c.fdsToClose = c.fdsToClose[:0]
+ c.fdsMu.Unlock()
+
+ req := CloseReq{FDs: toClose}
+ ctx.UninterruptibleSleepStart(false)
+ err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ log.Warningf("lisafs: batch closing FDs returned error: %v", err)
+ }
+}
+
+// SyncFDs makes a Fsync RPC to sync multiple FDs.
+func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error {
+ if len(fds) == 0 {
+ return nil
+ }
+ req := FsyncReq{FDs: fds}
+ ctx.UninterruptibleSleepStart(false)
+ err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// SndRcvMessage invokes reqMarshal to marshal the request onto the payload
+// buffer, wakes up the server to process the request, waits for the response
+// and invokes respUnmarshal with the response payload. respFDs is populated
+// with the received FDs; extra entries are set to -1.
+//
+// Note that the function arguments intentionally accept marshal.Marshallable
+// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of
+// directly accepting the marshal.Marshallable interface. Even though just
+// accepting marshal.Marshallable is cleaner, it leads to a heap allocation
+// (even if that interface variable itself does not escape). In other words,
+// implicit conversion to an interface leads to an allocation.
+//
+// Precondition: reqMarshal and respUnmarshal must be non-nil.
+func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error {
+ if !c.IsSupported(m) {
+ return unix.EOPNOTSUPP
+ }
+ if payloadLen > c.maxMessageSize {
+ log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize)
+ return unix.EIO
+ }
+ wantFDs := len(respFDs)
+ if wantFDs > math.MaxUint8 {
+ log.Warningf("want too many FDs: %d", wantFDs)
+ return unix.EINVAL
+ }
+
+ // Acquire a communicator.
+ comm := c.acquireCommunicator()
+ defer c.releaseCommunicator(comm)
+
+ // Marshal the request into comm's payload buffer and make the RPC.
+ reqMarshal(comm.PayloadBuf(payloadLen))
+ respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs))
+
+ // Handle FD donation.
+ rcvFDs := comm.ReleaseFDs()
+ if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 {
+ // releasedFDs is memory owned by comm which can not be returned to caller.
+ // Copy it into the caller's buffer.
+ numFDCopied := copy(respFDs, rcvFDs)
+ if numFDCopied < numRcvFDs {
+ log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs)
+ closeFDs(rcvFDs[numFDCopied:])
+ }
+ if numFDCopied < wantFDs {
+ for i := numFDCopied; i < wantFDs; i++ {
+ respFDs[i] = -1
+ }
+ }
+ }
+
+ // Error cases.
+ if err != nil {
+ closeFDs(respFDs)
+ return err
+ }
+ if respM == Error {
+ closeFDs(respFDs)
+ var resp ErrorResp
+ resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen))
+ return unix.Errno(resp.errno)
+ }
+ if respM != m {
+ closeFDs(respFDs)
+ log.Warningf("sent %d message but got %d in response", m, respM)
+ return unix.EINVAL
+ }
+
+ // Success. The payload must be unmarshalled *before* comm is released.
+ respUnmarshal(comm.PayloadBuf(respPayloadLen))
+ return nil
+}
+
+// Postcondition: releaseCommunicator() must be called on the returned value.
+func (c *Client) acquireCommunicator() Communicator {
+ // Prefer using channel over socket because:
+ // - Channel uses a shared memory region for passing messages. IO from shared
+ // memory is faster and does not involve making a syscall.
+ // - No intermediate buffer allocation needed. With a channel, the message
+ // can be directly pasted into the shared memory region.
+ if ch := c.getChannel(); ch != nil {
+ return ch
+ }
+
+ c.sockMu.Lock()
+ return c.sockComm
+}
+
+// Precondition: comm must have been acquired via acquireCommunicator().
+func (c *Client) releaseCommunicator(comm Communicator) {
+ switch t := comm.(type) {
+ case *sockCommunicator:
+ c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator().
+ case *channel:
+ c.releaseChannel(t)
+ default:
+ panic(fmt.Sprintf("unknown communicator type %T", t))
+ }
+}
+
+// getChannel pops a channel from the available channels stack. The caller must
+// release the channel after use.
+func (c *Client) getChannel() *channel {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ if len(c.availableChannels) == 0 {
+ return nil
+ }
+
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ c.activeWg.Add(1)
+ return ch
+}
+
+// releaseChannel pushes the passed channel back onto the available channel
+// stack, unless the channel is dead or the client is shutting down.
+func (c *Client) releaseChannel(ch *channel) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ // If availableChannels is nil, then watchdog has fired and the client is
+ // shutting down. So don't make this channel available again.
+ if !ch.dead && c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.activeWg.Done()
+}
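The comment on SndRcvMessage explains why it takes marshal/unmarshal functions rather than a marshal.Marshallable value: converting the concrete request to an interface at the call site is what triggers the heap allocation, while passing the bound method value does not. A stand-alone sketch of the two call shapes (all names here are hypothetical, not part of the diff):

package main

// marshallable stands in for marshal.Marshallable in this sketch.
type marshallable interface {
	MarshalBytes(dst []byte)
}

type mountReq struct{ /* fields elided */ }

func (m *mountReq) MarshalBytes(dst []byte) { /* marshal into dst */ }

// Interface-taking shape: &req is implicitly converted to an interface at
// the call site, which is the allocation the comment above wants to avoid.
func sendViaInterface(m marshallable, buf []byte) { m.MarshalBytes(buf) }

// Function-taking shape: the one SndRcvMessage actually uses.
func sendViaFunc(marshal func(dst []byte), buf []byte) { marshal(buf) }

func main() {
	var req mountReq
	buf := make([]byte, 64)
	sendViaInterface(&req, buf)        // implicit conversion to an interface
	sendViaFunc(req.MarshalBytes, buf) // bound method value, no conversion
}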
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go
new file mode 100644
index 000000000..0f8788f3b
--- /dev/null
+++ b/pkg/lisafs/client_file.go
@@ -0,0 +1,475 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// ClientFD is a wrapper around FDID that provides client-side utilities
+// so that making RPCs is easier.
+type ClientFD struct {
+ fd FDID
+ client *Client
+}
+
+// ID returns the underlying FDID.
+func (f *ClientFD) ID() FDID {
+ return f.fd
+}
+
+// Client returns the backing Client.
+func (f *ClientFD) Client() *Client {
+ return f.client
+}
+
+// NewFD initializes a new ClientFD.
+func (c *Client) NewFD(fd FDID) ClientFD {
+ return ClientFD{
+ client: c,
+ fd: fd,
+ }
+}
+
+// Ok returns true if the underlying FD is ok.
+func (f *ClientFD) Ok() bool {
+ return f.fd.Ok()
+}
+
+// CloseBatched queues this FD to be closed on the server and resets f.fd.
+// This may invoke the Close RPC if the queue is full.
+func (f *ClientFD) CloseBatched(ctx context.Context) {
+ f.client.CloseFDBatched(ctx, f.fd)
+ f.fd = InvalidFDID
+}
+
+// Close closes this FD immediately (invoking a Close RPC). Consider using
+// CloseBatched if closing this FD on the remote right away is not critical.
+func (f *ClientFD) Close(ctx context.Context) error {
+ fdArr := [1]FDID{f.fd}
+ req := CloseReq{FDs: fdArr[:]}
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// OpenAt makes the OpenAt RPC.
+func (f *ClientFD) OpenAt(ctx context.Context, flags uint32) (FDID, int, error) {
+ req := OpenAtReq{
+ FD: f.fd,
+ Flags: flags,
+ }
+ var respFD [1]int
+ var resp OpenAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(OpenAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, respFD[:])
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.NewFD, respFD[0], err
+}
+
+// OpenCreateAt makes the OpenCreateAt RPC.
+func (f *ClientFD) OpenCreateAt(ctx context.Context, name string, flags uint32, mode linux.FileMode, uid UID, gid GID) (Inode, FDID, int, error) {
+ var req OpenCreateAtReq
+ req.DirFD = f.fd
+ req.Name = SizedString(name)
+ req.Flags = primitive.Uint32(flags)
+ req.Mode = mode
+ req.UID = uid
+ req.GID = gid
+
+ var respFD [1]int
+ var resp OpenCreateAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(OpenCreateAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, respFD[:])
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Child, resp.NewFD, respFD[0], err
+}
+
+// StatTo makes the Fstat RPC and populates stat with the result.
+func (f *ClientFD) StatTo(ctx context.Context, stat *linux.Statx) error {
+ req := StatReq{FD: f.fd}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Sync makes the Fsync RPC.
+func (f *ClientFD) Sync(ctx context.Context) error {
+ req := FsyncReq{FDs: []FDID{f.fd}}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Read makes the PRead RPC.
+func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) {
+ req := PReadReq{
+ Offset: offset,
+ FD: f.fd,
+ Count: uint32(len(dst)),
+ }
+
+ resp := PReadResp{
+ // Set Buf up front so that the response is unmarshalled directly into dst
+ // and we don't need to allocate a temporary buffer during unmarshalling.
+ // PReadResp.UnmarshalBytes expects Buf to be set.
+ Buf: dst,
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return uint64(resp.NumBytes), err
+}
+
+// Write makes the PWrite RPC.
+func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) {
+ req := PWriteReq{
+ Offset: primitive.Uint64(offset),
+ FD: f.fd,
+ NumBytes: primitive.Uint32(len(src)),
+ Buf: src,
+ }
+
+ var resp PWriteResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Count, err
+}
+
+// MkdirAt makes the MkdirAt RPC.
+func (f *ClientFD) MkdirAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID) (*Inode, error) {
+ var req MkdirAtReq
+ req.DirFD = f.fd
+ req.Name = SizedString(name)
+ req.Mode = mode
+ req.UID = uid
+ req.GID = gid
+
+ var resp MkdirAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(MkdirAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.ChildDir, err
+}
+
+// SymlinkAt makes the SymlinkAt RPC.
+func (f *ClientFD) SymlinkAt(ctx context.Context, name, target string, uid UID, gid GID) (*Inode, error) {
+ req := SymlinkAtReq{
+ DirFD: f.fd,
+ Name: SizedString(name),
+ Target: SizedString(target),
+ UID: uid,
+ GID: gid,
+ }
+
+ var resp SymlinkAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(SymlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.Symlink, err
+}
+
+// LinkAt makes the LinkAt RPC.
+func (f *ClientFD) LinkAt(ctx context.Context, targetFD FDID, name string) (*Inode, error) {
+ req := LinkAtReq{
+ DirFD: f.fd,
+ Target: targetFD,
+ Name: SizedString(name),
+ }
+
+ var resp LinkAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(LinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.Link, err
+}
+
+// MknodAt makes the MknodAt RPC.
+func (f *ClientFD) MknodAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID, minor, major uint32) (*Inode, error) {
+ var req MknodAtReq
+ req.DirFD = f.fd
+ req.Name = SizedString(name)
+ req.Mode = mode
+ req.UID = uid
+ req.GID = gid
+ req.Minor = primitive.Uint32(minor)
+ req.Major = primitive.Uint32(major)
+
+ var resp MknodAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(MknodAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.Child, err
+}
+
+// SetStat makes the SetStat RPC.
+func (f *ClientFD) SetStat(ctx context.Context, stat *linux.Statx) (uint32, error, error) {
+ req := SetStatReq{
+ FD: f.fd,
+ Mask: stat.Mask,
+ Mode: uint32(stat.Mode),
+ UID: UID(stat.UID),
+ GID: GID(stat.GID),
+ Size: stat.Size,
+ Atime: linux.Timespec{
+ Sec: stat.Atime.Sec,
+ Nsec: int64(stat.Atime.Nsec),
+ },
+ Mtime: linux.Timespec{
+ Sec: stat.Mtime.Sec,
+ Nsec: int64(stat.Mtime.Nsec),
+ },
+ }
+
+ var resp SetStatResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(SetStat, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.FailureMask, unix.Errno(resp.FailureErrNo), err
+}
+
+// WalkMultiple makes the Walk RPC with multiple path components.
+func (f *ClientFD) WalkMultiple(ctx context.Context, names []string) (WalkStatus, []Inode, error) {
+ req := WalkReq{
+ DirFD: f.fd,
+ Path: StringArray(names),
+ }
+
+ var resp WalkResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Status, resp.Inodes, err
+}
+
+// Walk makes the Walk RPC with just one path component to walk.
+func (f *ClientFD) Walk(ctx context.Context, name string) (*Inode, error) {
+ req := WalkReq{
+ DirFD: f.fd,
+ Path: []string{name},
+ }
+
+ var inode [1]Inode
+ resp := WalkResp{Inodes: inode[:]}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return nil, err
+ }
+
+ switch resp.Status {
+ case WalkComponentDoesNotExist:
+ return nil, unix.ENOENT
+ case WalkComponentSymlink:
+ // f is not a directory which can be walked on.
+ return nil, unix.ENOTDIR
+ }
+
+ if n := len(resp.Inodes); n > 1 {
+ for i := range resp.Inodes {
+ f.client.CloseFDBatched(ctx, resp.Inodes[i].ControlFD)
+ }
+ log.Warningf("requested to walk one component, but got %d results", n)
+ return nil, unix.EIO
+ } else if n == 0 {
+ log.Warningf("walk has success status but no results returned")
+ return nil, unix.ENOENT
+ }
+ return &inode[0], err
+}
+
+// WalkStat makes the WalkStat RPC with multiple path components to walk.
+func (f *ClientFD) WalkStat(ctx context.Context, names []string) ([]linux.Statx, error) {
+ req := WalkReq{
+ DirFD: f.fd,
+ Path: StringArray(names),
+ }
+
+ var resp WalkStatResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(WalkStat, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Stats, err
+}
+
+// StatFSTo makes the FStatFS RPC and populates statFS with the result.
+func (f *ClientFD) StatFSTo(ctx context.Context, statFS *StatFS) error {
+ req := FStatFSReq{FD: f.fd}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FStatFS, uint32(req.SizeBytes()), req.MarshalUnsafe, statFS.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Allocate makes the FAllocate RPC.
+func (f *ClientFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ req := FAllocateReq{
+ FD: f.fd,
+ Mode: mode,
+ Offset: offset,
+ Length: length,
+ }
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FAllocate, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// ReadLinkAt makes the ReadLinkAt RPC.
+func (f *ClientFD) ReadLinkAt(ctx context.Context) (string, error) {
+ req := ReadLinkAtReq{FD: f.fd}
+ var resp ReadLinkAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(ReadLinkAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return string(resp.Target), err
+}
+
+// Flush makes the Flush RPC.
+func (f *ClientFD) Flush(ctx context.Context) error {
+ if !f.client.IsSupported(Flush) {
+ // If Flush is not supported, it probably means that it would be a noop.
+ return nil
+ }
+ req := FlushReq{FD: f.fd}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Flush, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Connect makes the Connect RPC.
+func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) {
+ req := ConnectReq{FD: f.fd, SockType: uint32(sockType)}
+ var sockFD [1]int
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, sockFD[:])
+ ctx.UninterruptibleSleepFinish(false)
+ if err == nil && sockFD[0] < 0 {
+ err = unix.EBADF
+ }
+ return sockFD[0], err
+}
+
+// UnlinkAt makes the UnlinkAt RPC.
+func (f *ClientFD) UnlinkAt(ctx context.Context, name string, flags uint32) error {
+ req := UnlinkAtReq{
+ DirFD: f.fd,
+ Name: SizedString(name),
+ Flags: primitive.Uint32(flags),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(UnlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// RenameTo makes the RenameAt RPC, which renames f into the directory
+// represented by newDirFD with the name newName.
+func (f *ClientFD) RenameTo(ctx context.Context, newDirFD FDID, newName string) error {
+ req := RenameAtReq{
+ Renamed: f.fd,
+ NewDir: newDirFD,
+ NewName: SizedString(newName),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(RenameAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Getdents64 makes the Getdents64 RPC.
+func (f *ClientFD) Getdents64(ctx context.Context, count int32) ([]Dirent64, error) {
+ req := Getdents64Req{
+ DirFD: f.fd,
+ Count: count,
+ }
+
+ var resp Getdents64Resp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Getdents64, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Dirents, err
+}
+
+// ListXattr makes the FListXattr RPC.
+func (f *ClientFD) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ req := FListXattrReq{
+ FD: f.fd,
+ Size: size,
+ }
+
+ var resp FListXattrResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FListXattr, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Xattrs, err
+}
+
+// GetXattr makes the FGetXattr RPC.
+func (f *ClientFD) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
+ req := FGetXattrReq{
+ FD: f.fd,
+ Name: SizedString(name),
+ BufSize: primitive.Uint32(size),
+ }
+
+ var resp FGetXattrResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FGetXattr, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return string(resp.Value), err
+}
+
+// SetXattr makes the FSetXattr RPC.
+func (f *ClientFD) SetXattr(ctx context.Context, name string, value string, flags uint32) error {
+ req := FSetXattrReq{
+ FD: f.fd,
+ Name: SizedString(name),
+ Value: SizedString(value),
+ Flags: primitive.Uint32(flags),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FSetXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// RemoveXattr makes the FRemoveXattr RPC.
+func (f *ClientFD) RemoveXattr(ctx context.Context, name string) error {
+ req := FRemoveXattrReq{
+ FD: f.fd,
+ Name: SizedString(name),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FRemoveXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go
new file mode 100644
index 000000000..ec2035158
--- /dev/null
+++ b/pkg/lisafs/communicator.go
@@ -0,0 +1,80 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import "golang.org/x/sys/unix"
+
+// Communicator is a server side utility which represents exactly how the
+// server is communicating with the client.
+type Communicator interface {
+ // PayloadBuf returns a slice to the payload section of its internal buffer
+ // where the message can be marshalled. The handlers should use this to
+ // populate the payload buffer with the message.
+ //
+ // The payload buffer contents *should* be preserved across calls with
+ // different sizes. Note that this is not a guarantee, because a compromised
+ // owner of a "shared" payload buffer can tamper with its contents anytime,
+ // even when it's not its turn to do so.
+ PayloadBuf(size uint32) []byte
+
+ // SndRcvMessage sends message m. The caller must have populated PayloadBuf()
+ // with payloadLen bytes. The caller expects to receive wantFDs FDs.
+ // Any received FDs must be accessible via ReleaseFDs(). It returns the
+ // response message along with the response payload length.
+ SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error)
+
+ // DonateFD makes fd non-blocking and starts tracking it. The next call to
+ // ReleaseFDs will include fd in the order it was added. Communicator takes
+ // ownership of fd. Server side should call this.
+ DonateFD(fd int) error
+
+ // TrackFD starts tracking fd. The next call to ReleaseFDs will include fd in
+ // the order it was added. Communicator takes ownership of fd. Client side
+ // should use this for accumulating received FDs.
+ TrackFD(fd int)
+
+ // ReleaseFDs returns the accumulated FDs and stops tracking them. The
+ // ownership of the FDs is transferred to the caller.
+ ReleaseFDs() []int
+}
+
+// fdTracker is a partial implementation of Communicator. It can be embedded in
+// Communicator implementations to keep track of FD donations.
+type fdTracker struct {
+ fds []int
+}
+
+// DonateFD implements Communicator.DonateFD.
+func (d *fdTracker) DonateFD(fd int) error {
+ // Make sure the FD is non-blocking.
+ if err := unix.SetNonblock(fd, true); err != nil {
+ unix.Close(fd)
+ return err
+ }
+ d.TrackFD(fd)
+ return nil
+}
+
+// TrackFD implements Communicator.TrackFD.
+func (d *fdTracker) TrackFD(fd int) {
+ d.fds = append(d.fds, fd)
+}
+
+// ReleaseFDs implements Communicator.ReleaseFDs.
+func (d *fdTracker) ReleaseFDs() []int {
+ ret := d.fds
+ d.fds = d.fds[:0]
+ return ret
+}
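From a server handler's point of view, the Communicator contract boils down to: read the request out of PayloadBuf, write the response back into PayloadBuf, and return the response payload length (dynamicMsgHandler in connection_test.go below is a real instance). A hypothetical minimal handler that just echoes its payload, assuming the lisafs import path gvisor.dev/gvisor/pkg/lisafs:

package handlers // hypothetical package, not part of the diff

import "gvisor.dev/gvisor/pkg/lisafs"

// rawEchoHandler is a sketch of a lisafs RPC handler. Because requests and
// responses share the communicator's payload buffer, echoing the request is
// just a matter of returning the same payload length.
func rawEchoHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
	_ = comm.PayloadBuf(payloadLen) // request bytes already sit where the response goes
	return payloadLen, nil
}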
diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go
new file mode 100644
index 000000000..f6e5ecb4f
--- /dev/null
+++ b/pkg/lisafs/connection.go
@@ -0,0 +1,320 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Connection represents a connection between a mount point in the client and a
+// mount point in the server. It is owned by the server on which it was started
+// and facilitates communication with the client mount.
+//
+// Each connection is set up using a unix domain socket. One end is owned by
+// the server and the other end is owned by the client. The connection may
+// spawn additional communication channels for the same mount for increased
+// RPC concurrency.
+type Connection struct {
+ // server is the server on which this connection was created. It is immutably
+ // associated with it for its entire lifetime.
+ server *Server
+
+ // mounted is a one way flag indicating whether this connection has been
+ // mounted correctly and the server is initialized properly.
+ mounted bool
+
+ // readonly indicates if this connection is readonly. All write operations
+ // will fail with EROFS.
+ readonly bool
+
+ // sockComm is the main socket by which this connection is established.
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+ // channels keeps track of all open channels.
+ channels []*channel
+
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // reqGate counts requests that are still being handled.
+ reqGate sync.Gate
+
+ // channelAlloc is used to allocate memory for channels.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ fdsMu sync.RWMutex
+ // fds keeps tracks of open FDs on this server. It is protected by fdsMu.
+ fds map[FDID]genericFD
+ // nextFDID is the next available FDID. It is protected by fdsMu.
+ nextFDID FDID
+}
+
+// CreateConnection initializes a new connection - creating a server if
+// required. The connection must be started separately.
+func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) {
+ c := &Connection{
+ sockComm: newSockComm(sock),
+ server: s,
+ readonly: readonly,
+ channels: make([]*channel, 0, maxChannels()),
+ fds: make(map[FDID]genericFD),
+ nextFDID: InvalidFDID + 1,
+ }
+
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return nil, err
+ }
+ c.channelAlloc = alloc
+ return c, nil
+}
+
+// Server returns the associated server.
+func (c *Connection) Server() *Server {
+ return c.server
+}
+
+// ServerImpl returns the associated server implementation.
+func (c *Connection) ServerImpl() ServerImpl {
+ return c.server.impl
+}
+
+// Run defines the lifecycle of a connection.
+func (c *Connection) Run() {
+ defer c.close()
+
+ // Start handling requests on this connection.
+ for {
+ m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */)
+ if err != nil {
+ log.Debugf("sock read failed, closing connection: %v", err)
+ return
+ }
+
+ respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen)
+ err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs)
+ closeFDs(respFDs)
+ if err != nil {
+ log.Debugf("sock write failed, closing connection: %v", err)
+ return
+ }
+ }
+}
+
+// service starts servicing the passed channel until the channel is shutdown.
+// This is a blocking method and hence must be called in a separate goroutine.
+func (c *Connection) service(ch *channel) error {
+ rcvDataLen, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rcvDataLen > 0 {
+ m, payloadLen, err := ch.rcvMsg(rcvDataLen)
+ if err != nil {
+ return err
+ }
+ respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen)
+ numFDs := ch.sendFDs(respFDs)
+ closeFDs(respFDs)
+
+ ch.marshalHdr(respM, numFDs)
+ rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) {
+ resp := &ErrorResp{errno: uint32(err)}
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return Error, respLen, nil
+}
+
+func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) {
+ if !c.reqGate.Enter() {
+ // c.close() has been called; the connection is shutting down.
+ return c.respondError(comm, unix.ECONNRESET)
+ }
+ defer c.reqGate.Leave()
+
+ if !c.mounted && m != Mount {
+ log.Warningf("connection must first be mounted")
+ return c.respondError(comm, unix.EINVAL)
+ }
+
+ // Check if the message is supported for forward compatibility.
+ if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil {
+ log.Warningf("received request which is not supported by the server, MID = %d", m)
+ return c.respondError(comm, unix.EOPNOTSUPP)
+ }
+
+ // Try handling the request.
+ respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen)
+ fds := comm.ReleaseFDs()
+ if err != nil {
+ closeFDs(fds)
+ return c.respondError(comm, p9.ExtractErrno(err))
+ }
+
+ return m, respPayloadLen, fds
+}
+
+func (c *Connection) close() {
+ // Wait for completion of all inflight requests. This is mostly so that if
+ // a request is stuck, the sandbox supervisor has the opportunity to kill
+ // us with SIGABRT to get a stack dump of the offending handler.
+ c.reqGate.Close()
+
+ // Shutdown and clean up channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.shutdown()
+ }
+ c.activeWg.Wait()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ // This is to prevent additional channels from being created.
+ c.channels = nil
+ c.channelsMu.Unlock()
+
+ // Free the channel memory.
+ if c.channelAlloc != nil {
+ c.channelAlloc.Destroy()
+ }
+
+ // Ensure the connection is closed.
+ c.sockComm.destroy()
+
+ // Cleanup all FDs.
+ c.fdsMu.Lock()
+ for fdid := range c.fds {
+ fd := c.removeFDLocked(fdid)
+ fd.DecRef(nil) // Drop the ref held by c.
+ }
+ c.fdsMu.Unlock()
+}
+
+// The caller gains a ref on the FD on success.
+func (c *Connection) lookupFD(id FDID) (genericFD, error) {
+ c.fdsMu.RLock()
+ defer c.fdsMu.RUnlock()
+
+ fd, ok := c.fds[id]
+ if !ok {
+ return nil, unix.EBADF
+ }
+ fd.IncRef()
+ return fd, nil
+}
+
+// LookupControlFD retrieves the control FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ cfd, ok := fd.(*ControlFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return cfd, nil
+}
+
+// LookupOpenFD retrieves the open FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ ofd, ok := fd.(*OpenFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return ofd, nil
+}
+
+// insertFD inserts the passed fd into the internal data structure to track FDs.
+// The caller must hold a ref on fd which is transferred to the connection.
+func (c *Connection) insertFD(fd genericFD) FDID {
+ c.fdsMu.Lock()
+ defer c.fdsMu.Unlock()
+
+ res := c.nextFDID
+ c.nextFDID++
+ if c.nextFDID < res {
+ panic("ran out of FDIDs")
+ }
+ c.fds[res] = fd
+ return res
+}
+
+// RemoveFD makes c stop tracking the passed FDID and drops its ref on it.
+func (c *Connection) RemoveFD(id FDID) {
+ c.fdsMu.Lock()
+ fd := c.removeFDLocked(id)
+ c.fdsMu.Unlock()
+ if fd != nil {
+ // Drop the ref held by c. This can take arbitrarily long. So do not hold
+ // c.fdsMu while calling it.
+ fd.DecRef(nil)
+ }
+}
+
+// RemoveControlFDLocked is the same as RemoveFD with added preconditions.
+//
+// Preconditions:
+// * server's rename mutex must at least be read locked.
+// * id must be pointing to a control FD.
+func (c *Connection) RemoveControlFDLocked(id FDID) {
+ c.fdsMu.Lock()
+ fd := c.removeFDLocked(id)
+ c.fdsMu.Unlock()
+ if fd != nil {
+ // Drop the ref held by c. This can take arbitrarily long. So do not hold
+ // c.fdsMu while calling it.
+ fd.(*ControlFD).DecRefLocked()
+ }
+}
+
+// removeFDLocked makes c stop tracking the passed FDID. Note that the caller
+// must drop ref on the returned fd (preferably without holding c.fdsMu).
+//
+// Precondition: c.fdsMu is locked.
+func (c *Connection) removeFDLocked(id FDID) genericFD {
+ fd := c.fds[id]
+ if fd == nil {
+ log.Warningf("removeFDLocked called on non-existent FDID %d", id)
+ return nil
+ }
+ delete(c.fds, id)
+ return fd
+}
diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go
new file mode 100644
index 000000000..28ba47112
--- /dev/null
+++ b/pkg/lisafs/connection_test.go
@@ -0,0 +1,194 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package connection_test
+
+import (
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ dynamicMsgID = lisafs.Channel + 1
+ versionMsgID = dynamicMsgID + 1
+)
+
+var handlers = [...]lisafs.RPCHandler{
+ lisafs.Error: lisafs.ErrorHandler,
+ lisafs.Mount: lisafs.MountHandler,
+ lisafs.Channel: lisafs.ChannelHandler,
+ dynamicMsgID: dynamicMsgHandler,
+ versionMsgID: versionHandler,
+}
+
+// testServer implements lisafs.ServerImpl.
+type testServer struct {
+ lisafs.Server
+}
+
+var _ lisafs.ServerImpl = (*testServer)(nil)
+
+type testControlFD struct {
+ lisafs.ControlFD
+ lisafs.ControlFDImpl
+}
+
+func (fd *testControlFD) FD() *lisafs.ControlFD {
+ return &fd.ControlFD
+}
+
+// Mount implements lisafs.Mount.
+func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) {
+ return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil
+}
+
+// MaxMessageSize implements lisafs.MaxMessageSize.
+func (s *testServer) MaxMessageSize() uint32 {
+ return lisafs.MaxMessageSize()
+}
+
+// SupportedMessages implements lisafs.ServerImpl.SupportedMessages.
+func (s *testServer) SupportedMessages() []lisafs.MID {
+ return []lisafs.MID{
+ lisafs.Mount,
+ lisafs.Channel,
+ dynamicMsgID,
+ versionMsgID,
+ }
+}
+
+func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) {
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+
+ ts := &testServer{}
+ ts.Server.InitTestOnly(ts, handlers[:])
+ conn, err := ts.CreateConnection(serverSocket, false /* readonly */)
+ if err != nil {
+ t.Fatalf("starting connection failed: %v", err)
+ return
+ }
+ ts.StartConnection(conn)
+
+ c, _, err := lisafs.NewClient(clientSocket, "/")
+ if err != nil {
+ t.Fatalf("client creation failed: %v", err)
+ }
+
+ clientFn(c)
+
+ c.Close() // This should trigger client and server shutdown.
+ ts.Wait()
+}
+
+// TestStartUp tests that the server and client can be started up correctly.
+func TestStartUp(t *testing.T) {
+ runServerClient(t, func(c *lisafs.Client) {
+ if c.IsSupported(lisafs.Error) {
+ t.Errorf("sending error messages should not be supported")
+ }
+ })
+}
+
+func TestUnsupportedMessage(t *testing.T) {
+ unsupportedM := lisafs.MID(len(handlers))
+ runServerClient(t, func(c *lisafs.Client) {
+ if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP {
+ t.Errorf("expected EOPNOTSUPP but got err: %v", err)
+ }
+ })
+}
+
+func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+ var req lisafs.MsgDynamic
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Just echo back the message.
+ respPayloadLen := uint32(req.SizeBytes())
+ req.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// TestStress stress tests sending many messages from various goroutines.
+func TestStress(t *testing.T) {
+ runServerClient(t, func(c *lisafs.Client) {
+ concurrency := 8
+ numMsgPerGoroutine := 5000
+ var clientWg sync.WaitGroup
+ for i := 0; i < concurrency; i++ {
+ clientWg.Add(1)
+ go func() {
+ defer clientWg.Done()
+
+ for j := 0; j < numMsgPerGoroutine; j++ {
+ // Create a massive random message.
+ var req lisafs.MsgDynamic
+ req.Randomize(100)
+
+ var resp lisafs.MsgDynamic
+ if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil {
+ t.Errorf("SndRcvMessage: received unexpected error %v", err)
+ return
+ }
+ if !reflect.DeepEqual(&req, &resp) {
+ t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp)
+ }
+ }
+ }()
+ }
+
+ clientWg.Wait()
+ })
+}
+
+func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+	// Usually handlers create their own objects and return pointers to those.
+	// It might be tempting to reuse the variables above, but don't.
+ var rv lisafs.P9Version
+ rv.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Create a new response.
+ sv := lisafs.P9Version{
+ MSize: rv.MSize,
+ Version: "9P2000.L.Google.11",
+ }
+ respPayloadLen := uint32(sv.SizeBytes())
+ sv.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel.
+func BenchmarkSendRecv(b *testing.B) {
+ b.ReportAllocs()
+ sendV := lisafs.P9Version{
+ MSize: 1 << 20,
+ Version: "9P2000.L.Google.12",
+ }
+
+ var recvV lisafs.P9Version
+ runServerClient(b, func(c *lisafs.Client) {
+ for i := 0; i < b.N; i++ {
+ if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil {
+ b.Fatalf("unexpected error occurred: %v", err)
+ }
+ }
+ })
+}
diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go
new file mode 100644
index 000000000..cc6919a1b
--- /dev/null
+++ b/pkg/lisafs/fd.go
@@ -0,0 +1,374 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// FDID (file descriptor identifier) is used to identify FDs on a connection.
+// Each connection has its own FDID namespace.
+//
+// +marshal slice:FDIDSlice
+type FDID uint32
+
+// InvalidFDID represents an invalid FDID.
+const InvalidFDID FDID = 0
+
+// Ok returns true if f is a valid FDID.
+func (f FDID) Ok() bool {
+ return f != InvalidFDID
+}
+
+// genericFD can represent a ControlFD or OpenFD.
+type genericFD interface {
+ refsvfs2.RefCounter
+}
+
+// A ControlFD is the gateway to the backing filesystem tree node. It is an
+// unusual concept. This exists to provide a safe way to do path-based
+// operations on the file. It performs operations that can modify the
+// filesystem tree and synchronizes these operations. See ControlFDImpl for
+// supported operations.
+//
+// It is not an inode, because multiple control FDs are allowed to exist on the
+// same file. It is not a file descriptor because it is not tied to any access
+// mode, i.e. a control FD can change its access mode based on the operation
+// being performed.
+//
+// Reference Model:
+// * When a control FD is created, the connection takes a ref on it which
+// represents the client's ref on the FD.
+// * The client can drop its ref via the Close RPC which will in turn make the
+// connection drop its ref.
+// * Each control FD holds a ref on its parent for its entire life time.
+type ControlFD struct {
+ controlFDRefs
+ controlFDEntry
+
+ // parent is the parent directory FD containing the file this FD represents.
+ // A ControlFD holds a ref on parent for its entire lifetime. If this FD
+ // represents the root, then parent is nil. parent may be a control FD from
+ // another connection (another mount point). parent is protected by the
+ // backing server's rename mutex.
+ parent *ControlFD
+
+ // name is the file path's last component name. If this FD represents the
+ // root directory, then name is the mount path. name is protected by the
+ // backing server's rename mutex.
+ name string
+
+ // children is a linked list of all children control FDs. As per reference
+ // model, all children hold a ref on this FD.
+	// children is protected by childrenMu and the server's rename mutex. For
+	// mutual exclusion, it is sufficient to either:
+	//   * hold the rename mutex for reading and lock childrenMu, or
+	//   * hold the rename mutex for writing.
+ childrenMu sync.Mutex
+ children controlFDList
+
+ // openFDs is a linked list of all FDs opened on this FD. As per reference
+ // model, all open FDs hold a ref on this FD.
+ openFDsMu sync.RWMutex
+ openFDs openFDList
+
+ // All the following fields are immutable.
+
+ // id is the unique FD identifier which identifies this FD on its connection.
+ id FDID
+
+ // conn is the backing connection owning this FD.
+ conn *Connection
+
+ // ftype is the file type of the backing inode. ftype.FileType() == ftype.
+ ftype linux.FileMode
+
+ // impl is the control FD implementation which embeds this struct. It
+ // contains all the implementation specific details.
+ impl ControlFDImpl
+}
+
+var _ genericFD = (*ControlFD)(nil)
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+func (fd *ControlFD) DecRef(context.Context) {
+ fd.controlFDRefs.DecRef(func() {
+ if fd.parent != nil {
+ fd.conn.server.RenameMu.RLock()
+ fd.parent.childrenMu.Lock()
+ fd.parent.children.Remove(fd)
+ fd.parent.childrenMu.Unlock()
+ fd.conn.server.RenameMu.RUnlock()
+ fd.parent.DecRef(nil) // Drop the ref on the parent.
+ }
+ fd.impl.Close(fd.conn)
+ })
+}
+
+// DecRefLocked is the same as DecRef except for the added precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) DecRefLocked() {
+ fd.controlFDRefs.DecRef(func() {
+ fd.clearParentLocked()
+ fd.impl.Close(fd.conn)
+ })
+}
+
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) clearParentLocked() {
+ if fd.parent == nil {
+ return
+ }
+ fd.parent.childrenMu.Lock()
+ fd.parent.children.Remove(fd)
+ fd.parent.childrenMu.Unlock()
+ fd.parent.DecRefLocked() // Drop the ref on the parent.
+}
+
+// Init must be called before first use of fd. It inserts fd into the
+// filesystem tree.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) {
+ // Initialize fd with 1 ref which is transferred to c via c.insertFD().
+ fd.controlFDRefs.InitRefs()
+ fd.conn = c
+ fd.id = c.insertFD(fd)
+ fd.name = name
+ fd.ftype = mode.FileType()
+ fd.impl = impl
+ fd.setParentLocked(parent)
+}
+
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) setParentLocked(parent *ControlFD) {
+ fd.parent = parent
+ if parent != nil {
+ parent.IncRef() // Hold a ref on parent.
+ parent.childrenMu.Lock()
+ parent.children.PushBack(fd)
+ parent.childrenMu.Unlock()
+ }
+}
+
+// FileType returns the file mode only containing the file type bits.
+func (fd *ControlFD) FileType() linux.FileMode {
+ return fd.ftype
+}
+
+// IsDir indicates whether fd represents a directory.
+func (fd *ControlFD) IsDir() bool {
+ return fd.ftype == unix.S_IFDIR
+}
+
+// IsRegular indicates whether fd represents a regular file.
+func (fd *ControlFD) IsRegular() bool {
+ return fd.ftype == unix.S_IFREG
+}
+
+// IsSymlink indicates whether fd represents a symbolic link.
+func (fd *ControlFD) IsSymlink() bool {
+ return fd.ftype == unix.S_IFLNK
+}
+
+// IsSocket indicates whether fd represents a socket.
+func (fd *ControlFD) IsSocket() bool {
+ return fd.ftype == unix.S_IFSOCK
+}
+
+// NameLocked returns the backing file's last component name.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) NameLocked() string {
+ return fd.name
+}
+
+// ParentLocked returns the parent control FD. Note that the parent might be a
+// control FD from another connection on this server, so its ID must not be
+// returned on this connection because FDIDs are local to their connection.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) ParentLocked() ControlFDImpl {
+ if fd.parent == nil {
+ return nil
+ }
+ return fd.parent.impl
+}
+
+// ID returns fd's ID.
+func (fd *ControlFD) ID() FDID {
+ return fd.id
+}
+
+// FilePath returns the absolute path of the file fd was opened on. This is
+// expensive and must not be called on hot paths. FilePath acquires the rename
+// mutex for reading so callers should not be holding it.
+func (fd *ControlFD) FilePath() string {
+ // Lock the rename mutex for reading to ensure that the filesystem tree is not
+ // changed while we traverse it upwards.
+ fd.conn.server.RenameMu.RLock()
+ defer fd.conn.server.RenameMu.RUnlock()
+ return fd.FilePathLocked()
+}
+
+// FilePathLocked is the same as FilePath with the additional precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) FilePathLocked() string {
+ // Walk upwards and prepend name to res.
+ var res fspath.Builder
+ for fd != nil {
+ res.PrependComponent(fd.name)
+ fd = fd.parent
+ }
+ return res.String()
+}
+
+// ForEachOpenFD executes fn on each FD opened on fd.
+func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) {
+ fd.openFDsMu.RLock()
+ defer fd.openFDsMu.RUnlock()
+ for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() {
+ fn(ofd.impl)
+ }
+}
+
+// OpenFD represents an open file descriptor on the protocol. It corresponds
+// closely to a Linux file descriptor. Its operations are limited to the file
+// itself; they are not allowed to modify or traverse the filesystem tree. See
+// OpenFDImpl for the supported operations.
+//
+// Reference Model:
+// * An OpenFD takes a reference on the control FD it was opened on.
+type OpenFD struct {
+ openFDRefs
+ openFDEntry
+
+ // All the following fields are immutable.
+
+ // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref
+ // on controlFD for its entire lifetime.
+ controlFD *ControlFD
+
+ // id is the unique FD identifier which identifies this FD on its connection.
+ id FDID
+
+ // Access mode for this FD.
+ readable bool
+ writable bool
+
+ // impl is the open FD implementation which embeds this struct. It
+ // contains all the implementation specific details.
+ impl OpenFDImpl
+}
+
+var _ genericFD = (*OpenFD)(nil)
+
+// ID returns fd's ID.
+func (fd *OpenFD) ID() FDID {
+ return fd.id
+}
+
+// ControlFD returns the control FD on which this FD was opened.
+func (fd *OpenFD) ControlFD() ControlFDImpl {
+ return fd.controlFD.impl
+}
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+func (fd *OpenFD) DecRef(context.Context) {
+ fd.openFDRefs.DecRef(func() {
+ fd.controlFD.openFDsMu.Lock()
+ fd.controlFD.openFDs.Remove(fd)
+ fd.controlFD.openFDsMu.Unlock()
+ fd.controlFD.DecRef(nil) // Drop the ref on the control FD.
+ fd.impl.Close(fd.controlFD.conn)
+ })
+}
+
+// Init must be called before first use of fd.
+func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) {
+ // Initialize fd with 1 ref which is transferred to c via c.insertFD().
+ fd.openFDRefs.InitRefs()
+ fd.controlFD = cfd
+ fd.id = cfd.conn.insertFD(fd)
+ accessMode := flags & unix.O_ACCMODE
+ fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR
+ fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR
+ fd.impl = impl
+ cfd.IncRef() // Holds a ref on cfd for its lifetime.
+ cfd.openFDsMu.Lock()
+ cfd.openFDs.PushBack(fd)
+ cfd.openFDsMu.Unlock()
+}
+
+// ControlFDImpl contains implementation details for a ControlFD.
+// Implementations of ControlFDImpl should contain their associated ControlFD
+// by value as their first field.
+//
+// The operations that perform path traversal or any modification to the
+// filesystem tree must synchronize those modifications with the server's
+// rename mutex.
+type ControlFDImpl interface {
+ FD() *ControlFD
+ Close(c *Connection)
+ Stat(c *Connection, comm Communicator) (uint32, error)
+ SetStat(c *Connection, comm Communicator, stat SetStatReq) (uint32, error)
+ Walk(c *Connection, comm Communicator, path StringArray) (uint32, error)
+ WalkStat(c *Connection, comm Communicator, path StringArray) (uint32, error)
+ Open(c *Connection, comm Communicator, flags uint32) (uint32, error)
+ OpenCreate(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, flags uint32) (uint32, error)
+ Mkdir(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string) (uint32, error)
+ Mknod(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, minor uint32, major uint32) (uint32, error)
+ Symlink(c *Connection, comm Communicator, name string, target string, uid UID, gid GID) (uint32, error)
+ Link(c *Connection, comm Communicator, dir ControlFDImpl, name string) (uint32, error)
+ StatFS(c *Connection, comm Communicator) (uint32, error)
+ Readlink(c *Connection, comm Communicator) (uint32, error)
+ Connect(c *Connection, comm Communicator, sockType uint32) error
+ Unlink(c *Connection, name string, flags uint32) error
+ RenameLocked(c *Connection, newDir ControlFDImpl, newName string) (func(ControlFDImpl), func(), error)
+ GetXattr(c *Connection, comm Communicator, name string, size uint32) (uint32, error)
+ SetXattr(c *Connection, name string, value string, flags uint32) error
+ ListXattr(c *Connection, comm Communicator, size uint64) (uint32, error)
+ RemoveXattr(c *Connection, comm Communicator, name string) error
+}
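+
+// The sketch below (hypothetical; not part of this package) illustrates the
+// embedding requirement above: an implementation embeds its ControlFD by
+// value as the first field and returns a pointer to it from FD().
+//
+//	type myControlFD struct {
+//		controlFD lisafs.ControlFD // must be the first field
+//		hostFD    int              // implementation-specific state
+//		// ... remaining ControlFDImpl methods operate on hostFD ...
+//	}
+//
+//	func (fd *myControlFD) FD() *lisafs.ControlFD {
+//		return &fd.controlFD
+//	}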
+
+// OpenFDImpl contains implementation details for an OpenFD. Implementations of
+// OpenFDImpl should contain their associated OpenFD by value as their first
+// field.
+//
+// Since these operations do not perform any path traversal or any modification
+// to the filesystem tree, there is no need to synchronize with rename
+// operations.
+type OpenFDImpl interface {
+ FD() *OpenFD
+ Close(c *Connection)
+ Stat(c *Connection, comm Communicator) (uint32, error)
+ Sync(c *Connection) error
+ Write(c *Connection, comm Communicator, buf []byte, off uint64) (uint32, error)
+ Read(c *Connection, comm Communicator, off uint64, count uint32) (uint32, error)
+ Allocate(c *Connection, mode, off, length uint64) error
+ Flush(c *Connection) error
+ Getdent64(c *Connection, comm Communicator, count uint32, seek0 bool) (uint32, error)
+}
diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go
new file mode 100644
index 000000000..82807734d
--- /dev/null
+++ b/pkg/lisafs/handlers.go
@@ -0,0 +1,768 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+const (
+ allowedOpenFlags = unix.O_ACCMODE | unix.O_TRUNC
+ setStatSupportedMask = unix.STATX_MODE | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_ATIME | unix.STATX_MTIME
+)
+
+// RPCHandler defines a handler that is invoked when the associated message is
+// received. The handler is responsible for:
+//
+// * Unmarshalling the request from the passed payload and interpreting it.
+// * Marshalling the response into the communicator's payload buffer.
+// * Returning the number of payload bytes written.
+// * Donating any FDs (if needed) to comm, which in turn donates them to the client.
+type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error)
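+
+// The sketch below shows the contract a handler is expected to follow
+// (EchoHandler is hypothetical; MsgDynamic is the sample message type used in
+// tests): unmarshal the request from comm.PayloadBuf(payloadLen), marshal the
+// response into the payload buffer and return the number of response payload
+// bytes written.
+//
+//	func EchoHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+//		var req MsgDynamic
+//		req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+//		respPayloadLen := uint32(req.SizeBytes())
+//		req.MarshalBytes(comm.PayloadBuf(respPayloadLen)) // Echo the request back.
+//		return respPayloadLen, nil
+//	}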
+
+var handlers = [...]RPCHandler{
+ Error: ErrorHandler,
+ Mount: MountHandler,
+ Channel: ChannelHandler,
+ FStat: FStatHandler,
+ SetStat: SetStatHandler,
+ Walk: WalkHandler,
+ WalkStat: WalkStatHandler,
+ OpenAt: OpenAtHandler,
+ OpenCreateAt: OpenCreateAtHandler,
+ Close: CloseHandler,
+ FSync: FSyncHandler,
+ PWrite: PWriteHandler,
+ PRead: PReadHandler,
+ MkdirAt: MkdirAtHandler,
+ MknodAt: MknodAtHandler,
+ SymlinkAt: SymlinkAtHandler,
+ LinkAt: LinkAtHandler,
+ FStatFS: FStatFSHandler,
+ FAllocate: FAllocateHandler,
+ ReadLinkAt: ReadLinkAtHandler,
+ Flush: FlushHandler,
+ Connect: ConnectHandler,
+ UnlinkAt: UnlinkAtHandler,
+ RenameAt: RenameAtHandler,
+ Getdents64: Getdents64Handler,
+ FGetXattr: FGetXattrHandler,
+ FSetXattr: FSetXattrHandler,
+ FListXattr: FListXattrHandler,
+ FRemoveXattr: FRemoveXattrHandler,
+}
+
+// ErrorHandler handles Error message.
+func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ // Client should never send Error.
+ return 0, unix.EINVAL
+}
+
+// MountHandler handles the Mount RPC. Note that there cannot be concurrent
+// executions of MountHandler on a connection because the connection enforces
+// that Mount is the first message on the connection. Only after the connection
+// has been successfully mounted can other channels be created.
+func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req MountReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ mountPath := path.Clean(string(req.MountPath))
+ if !filepath.IsAbs(mountPath) {
+ log.Warningf("mountPath %q is not absolute", mountPath)
+ return 0, unix.EINVAL
+ }
+
+ if c.mounted {
+ log.Warningf("connection has already been mounted at %q", mountPath)
+ return 0, unix.EBUSY
+ }
+
+ rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath)
+ if err != nil {
+ return 0, err
+ }
+
+ c.server.addMountPoint(rootFD.FD())
+ c.mounted = true
+ resp := MountResp{
+ Root: rootIno,
+ SupportedMs: c.ServerImpl().SupportedMessages(),
+ MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()),
+ }
+ respPayloadLen := uint32(resp.SizeBytes())
+ resp.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// ChannelHandler handles the Channel RPC.
+func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize())
+ if err != nil {
+ return 0, err
+ }
+
+ // Start servicing the channel in a separate goroutine.
+ c.activeWg.Add(1)
+ go func() {
+ if err := c.service(ch); err != nil {
+ // Don't log shutdown error which is expected during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err)
+ }
+ }
+ c.activeWg.Done()
+ }()
+
+ clientDataFD, err := unix.Dup(desc.FD)
+ if err != nil {
+ unix.Close(fdSock)
+ ch.shutdown()
+ return 0, err
+ }
+
+ // Respond to client with successful channel creation message.
+ if err := comm.DonateFD(clientDataFD); err != nil {
+ return 0, err
+ }
+ if err := comm.DonateFD(fdSock); err != nil {
+ return 0, err
+ }
+ resp := ChannelResp{
+ dataOffset: desc.Offset,
+ dataLength: uint64(desc.Length),
+ }
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
+// FStatHandler handles the FStat RPC.
+func FStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req StatReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.lookupFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ switch t := fd.(type) {
+ case *ControlFD:
+ return t.impl.Stat(c, comm)
+ case *OpenFD:
+ return t.impl.Stat(c, comm)
+ default:
+ panic(fmt.Sprintf("unknown fd type %T", t))
+ }
+}
+
+// SetStatHandler handles the SetStat RPC.
+func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+
+ var req SetStatReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ if req.Mask&^setStatSupportedMask != 0 {
+ return 0, unix.EPERM
+ }
+
+ return fd.impl.SetStat(c, comm, req)
+}
+
+// WalkHandler handles the Walk RPC.
+func WalkHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req WalkReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ for _, name := range req.Path {
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+ }
+
+ return fd.impl.Walk(c, comm, req.Path)
+}
+
+// WalkStatHandler handles the WalkStat RPC.
+func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req WalkReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ // Note that this fd is allowed to not actually be a directory when the
+ // only path component to walk is "" (self).
+ if !fd.IsDir() {
+ if len(req.Path) > 1 || (len(req.Path) == 1 && len(req.Path[0]) > 0) {
+ return 0, unix.ENOTDIR
+ }
+ }
+ for i, name := range req.Path {
+ // First component is allowed to be "".
+ if i == 0 && len(name) == 0 {
+ continue
+ }
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+ }
+
+ return fd.impl.WalkStat(c, comm, req.Path)
+}
+
+// OpenAtHandler handles the OpenAt RPC.
+func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req OpenAtReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ // Only keep allowed open flags.
+ if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags {
+ log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags)
+ req.Flags = allowedFlags
+ }
+
+ accessMode := req.Flags & unix.O_ACCMODE
+ trunc := req.Flags&unix.O_TRUNC != 0
+ if c.readonly && (accessMode != unix.O_RDONLY || trunc) {
+ return 0, unix.EROFS
+ }
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if fd.IsDir() {
+ // Directory is not truncatable and must be opened with O_RDONLY.
+ if accessMode != unix.O_RDONLY || trunc {
+ return 0, unix.EISDIR
+ }
+ }
+
+ return fd.impl.Open(c, comm, req.Flags)
+}
+
+// OpenCreateAtHandler handles the OpenCreateAt RPC.
+func OpenCreateAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req OpenCreateAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Only keep allowed open flags.
+ if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags {
+ log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags)
+ req.Flags = allowedFlags
+ }
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+ return fd.impl.OpenCreate(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Flags))
+}
+
+// CloseHandler handles the Close RPC.
+func CloseHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req CloseReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+ for _, fd := range req.FDs {
+ c.RemoveFD(fd)
+ }
+
+ // There is no response message for this.
+ return 0, nil
+}
+
+// FSyncHandler handles the FSync RPC.
+func FSyncHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FsyncReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Return the first error we encounter, but sync everything we can
+ // regardless.
+ var retErr error
+ for _, fdid := range req.FDs {
+ if err := c.fsyncFD(fdid); err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ // There is no response message for this.
+ return 0, retErr
+}
+
+func (c *Connection) fsyncFD(id FDID) error {
+ fd, err := c.LookupOpenFD(id)
+ if err != nil {
+ return err
+ }
+ return fd.impl.Sync(c)
+}
+
+// PWriteHandler handles the PWrite RPC.
+func PWriteHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req PWriteReq
+ // Note that it is an optimized Unmarshal operation which avoids any buffer
+ // allocation and copying. req.Buf just points to payload. This is safe to do
+ // as the handler owns payload and req's lifetime is limited to the handler.
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ if !fd.writable {
+ return 0, unix.EBADF
+ }
+ return fd.impl.Write(c, comm, req.Buf, uint64(req.Offset))
+}
+
+// PReadHandler handles the PRead RPC.
+func PReadHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req PReadReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.readable {
+ return 0, unix.EBADF
+ }
+ return fd.impl.Read(c, comm, req.Offset, req.Count)
+}
+
+// MkdirAtHandler handles the MkdirAt RPC.
+func MkdirAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req MkdirAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return fd.impl.Mkdir(c, comm, req.Mode, req.UID, req.GID, name)
+}
+
+// MknodAtHandler handles the MknodAt RPC.
+func MknodAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req MknodAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return fd.impl.Mknod(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Minor), uint32(req.Major))
+}
+
+// SymlinkAtHandler handles the SymlinkAt RPC.
+func SymlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req SymlinkAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return fd.impl.Symlink(c, comm, name, string(req.Target), req.UID, req.GID)
+}
+
+// LinkAtHandler handles the LinkAt RPC.
+func LinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req LinkAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+ targetFD, err := c.LookupControlFD(req.Target)
+ if err != nil {
+ return 0, err
+ }
+ return targetFD.impl.Link(c, comm, fd.impl, name)
+}
+
+// FStatFSHandler handles the FStatFS RPC.
+func FStatFSHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FStatFSReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.StatFS(c, comm)
+}
+
+// FAllocateHandler handles the FAllocate RPC.
+func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req FAllocateReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.writable {
+ return 0, unix.EBADF
+ }
+ return 0, fd.impl.Allocate(c, req.Mode, req.Offset, req.Length)
+}
+
+// ReadLinkAtHandler handles the ReadLinkAt RPC.
+func ReadLinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req ReadLinkAtReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsSymlink() {
+ return 0, unix.EINVAL
+ }
+ return fd.impl.Readlink(c, comm)
+}
+
+// FlushHandler handles the Flush RPC.
+func FlushHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FlushReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ return 0, fd.impl.Flush(c)
+}
+
+// ConnectHandler handles the Connect RPC.
+func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req ConnectReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsSocket() {
+ return 0, unix.ENOTSOCK
+ }
+ return 0, fd.impl.Connect(c, comm, req.SockType)
+}
+
+// UnlinkAtHandler handles the UnlinkAt RPC.
+func UnlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req UnlinkAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return 0, fd.impl.Unlink(c, name, uint32(req.Flags))
+}
+
+// RenameAtHandler handles the RenameAt RPC.
+func RenameAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req RenameAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ newName := string(req.NewName)
+ if err := checkSafeName(newName); err != nil {
+ return 0, err
+ }
+
+ renamed, err := c.LookupControlFD(req.Renamed)
+ if err != nil {
+ return 0, err
+ }
+ defer renamed.DecRef(nil)
+
+ newDir, err := c.LookupControlFD(req.NewDir)
+ if err != nil {
+ return 0, err
+ }
+ defer newDir.DecRef(nil)
+ if !newDir.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+	// Hold RenameMu for writing during the rename; this is important.
+ c.server.RenameMu.Lock()
+ defer c.server.RenameMu.Unlock()
+
+ if renamed.parent == nil {
+ // renamed is root.
+ return 0, unix.EBUSY
+ }
+
+ oldParentPath := renamed.parent.FilePathLocked()
+ oldPath := oldParentPath + "/" + renamed.name
+ if newName == renamed.name && oldParentPath == newDir.FilePathLocked() {
+ // Nothing to do.
+ return 0, nil
+ }
+
+ updateControlFD, cleanUp, err := renamed.impl.RenameLocked(c, newDir.impl, newName)
+ if err != nil {
+ return 0, err
+ }
+
+ c.server.forEachMountPoint(func(root *ControlFD) {
+ if !strings.HasPrefix(oldPath, root.name) {
+ return
+ }
+ pit := fspath.Parse(oldPath[len(root.name):]).Begin
+ root.renameRecursiveLocked(newDir, newName, pit, updateControlFD)
+ })
+
+ if cleanUp != nil {
+ cleanUp()
+ }
+ return 0, nil
+}
+
+// Precondition: rename mutex must be locked for writing.
+func (fd *ControlFD) renameRecursiveLocked(newDir *ControlFD, newName string, pit fspath.Iterator, updateControlFD func(ControlFDImpl)) {
+ if !pit.Ok() {
+ // fd should be renamed.
+ fd.clearParentLocked()
+ fd.setParentLocked(newDir)
+ fd.name = newName
+ if updateControlFD != nil {
+ updateControlFD(fd.impl)
+ }
+ return
+ }
+
+ cur := pit.String()
+ next := pit.Next()
+ // No need to hold fd.childrenMu because RenameMu is locked for writing.
+ for child := fd.children.Front(); child != nil; child = child.Next() {
+ if child.name == cur {
+ child.renameRecursiveLocked(newDir, newName, next, updateControlFD)
+ }
+ }
+}
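+
+// For example (illustrative trace, assuming a mount point rooted at "/mnt"):
+// renaming "/mnt/a/b" into directory "/mnt/c" with new name "d" computes
+// oldPath = "/mnt/a/b", walks the iterator over "a" then "b", and every
+// control FD whose path matches "/mnt/a/b" is re-parented to newDir, renamed
+// to "d" and passed to updateControlFD so the implementation can update its
+// own state.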
+
+// Getdents64Handler handles the Getdents64 RPC.
+func Getdents64Handler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req Getdents64Req
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.controlFD.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+ seek0 := false
+ if req.Count < 0 {
+ seek0 = true
+ req.Count = -req.Count
+ }
+ return fd.impl.Getdent64(c, comm, uint32(req.Count), seek0)
+}
+
+// FGetXattrHandler handles the FGetXattr RPC.
+func FGetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FGetXattrReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.GetXattr(c, comm, string(req.Name), uint32(req.BufSize))
+}
+
+// FSetXattrHandler handles the FSetXattr RPC.
+func FSetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req FSetXattrReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return 0, fd.impl.SetXattr(c, string(req.Name), string(req.Value), uint32(req.Flags))
+}
+
+// FListXattrHandler handles the FListXattr RPC.
+func FListXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FListXattrReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.ListXattr(c, comm, req.Size)
+}
+
+// FRemoveXattrHandler handles the FRemoveXattr RPC.
+func FRemoveXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req FRemoveXattrReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return 0, fd.impl.RemoveXattr(c, comm, string(req.Name))
+}
+
+// checkSafeName returns nil if name is safe to use as a path component (it is
+// non-empty, contains no '/' and is not "." or ".."); otherwise it returns an
+// error.
+func checkSafeName(name string) error {
+ if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." {
+ return nil
+ }
+ return unix.EINVAL
+}
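+
+// For illustration, checkSafeName accepts plain file names and rejects
+// anything that could alias or escape the directory:
+//
+//	checkSafeName("foo")    // nil
+//	checkSafeName(".cfg")   // nil
+//	checkSafeName("")       // unix.EINVAL
+//	checkSafeName(".")      // unix.EINVAL
+//	checkSafeName("..")     // unix.EINVAL
+//	checkSafeName("a/b")    // unix.EINVAL (contains '/')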
diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go
new file mode 100644
index 000000000..4d8e956ab
--- /dev/null
+++ b/pkg/lisafs/lisafs.go
@@ -0,0 +1,18 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lisafs (LInux SAndbox FileSystem) defines the protocol for
+// filesystem RPCs between an untrusted Sandbox (client) and a trusted
+// filesystem server.
+package lisafs
diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go
new file mode 100644
index 000000000..722afd0be
--- /dev/null
+++ b/pkg/lisafs/message.go
@@ -0,0 +1,1251 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// Messages have two parts:
+// * A transport header used to decipher received messages.
+// * A byte array referred to as "payload" which contains the actual message.
+//
+// "dataLen" refers to the size of both combined.
+
+// MID (message ID) is used to identify messages to parse from payload.
+//
+// +marshal slice:MIDSlice
+type MID uint16
+
+// These constants are used to identify their corresponding message types.
+const (
+ // Error is only used in responses to pass errors to client.
+ Error MID = 0
+
+ // Mount is used to establish connection between the client and server mount
+ // point. lisafs requires that the client makes a successful Mount RPC before
+ // making other RPCs.
+ Mount MID = 1
+
+ // Channel requests to start a new communicational channel.
+ Channel MID = 2
+
+ // FStat requests the stat(2) results for a specified file.
+ FStat MID = 3
+
+	// SetStat requests to change file attributes. Note that there is no single
+	// corresponding Linux syscall. It is a conglomeration of fchmod(2),
+	// fchown(2), ftruncate(2) and futimesat(2).
+ SetStat MID = 4
+
+ // Walk requests to walk the specified path starting from the specified
+ // directory. Server-side path traversal is terminated preemptively on
+	// symlink entries because they can cause non-linear traversal.
+ Walk MID = 5
+
+	// WalkStat is the same as Walk, except for the following differences:
+ // * If the first path component is "", then it also returns stat results
+ // for the directory where the walk starts.
+ // * Does not return Inode, just the Stat results for each path component.
+ WalkStat MID = 6
+
+ // OpenAt is analogous to openat(2). It does not perform any walk. It merely
+ // duplicates the control FD with the open flags passed.
+ OpenAt MID = 7
+
+ // OpenCreateAt is analogous to openat(2) with O_CREAT|O_EXCL added to flags.
+ // It also returns the newly created file inode.
+ OpenCreateAt MID = 8
+
+ // Close is analogous to close(2) but can work on multiple FDs.
+ Close MID = 9
+
+ // FSync is analogous to fsync(2) but can work on multiple FDs.
+ FSync MID = 10
+
+ // PWrite is analogous to pwrite(2).
+ PWrite MID = 11
+
+ // PRead is analogous to pread(2).
+ PRead MID = 12
+
+ // MkdirAt is analogous to mkdirat(2).
+ MkdirAt MID = 13
+
+ // MknodAt is analogous to mknodat(2).
+ MknodAt MID = 14
+
+ // SymlinkAt is analogous to symlinkat(2).
+ SymlinkAt MID = 15
+
+ // LinkAt is analogous to linkat(2).
+ LinkAt MID = 16
+
+ // FStatFS is analogous to fstatfs(2).
+ FStatFS MID = 17
+
+ // FAllocate is analogous to fallocate(2).
+ FAllocate MID = 18
+
+ // ReadLinkAt is analogous to readlinkat(2).
+ ReadLinkAt MID = 19
+
+ // Flush cleans up the file state. Its behavior is implementation
+ // dependent and might not even be supported in server implementations.
+ Flush MID = 20
+
+ // Connect is loosely analogous to connect(2).
+ Connect MID = 21
+
+ // UnlinkAt is analogous to unlinkat(2).
+ UnlinkAt MID = 22
+
+ // RenameAt is loosely analogous to renameat(2).
+ RenameAt MID = 23
+
+ // Getdents64 is analogous to getdents64(2).
+ Getdents64 MID = 24
+
+ // FGetXattr is analogous to fgetxattr(2).
+ FGetXattr MID = 25
+
+ // FSetXattr is analogous to fsetxattr(2).
+ FSetXattr MID = 26
+
+ // FListXattr is analogous to flistxattr(2).
+ FListXattr MID = 27
+
+ // FRemoveXattr is analogous to fremovexattr(2).
+ FRemoveXattr MID = 28
+)
+
+const (
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MaxMessageSize is the recommended max message size that can be used by
+// connections. Server implementations may choose to use other values.
+func MaxMessageSize() uint32 {
+ // Return HugePageSize - PageSize so that when flipcall packet window is
+ // created with MaxMessageSize() + flipcall header size + channel header
+ // size, HugePageSize is allocated and can be backed by a single huge page
+ // if supported by the underlying memfd.
+ return uint32(hostarch.HugePageSize - os.Getpagesize())
+}
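+
+// For example, on a typical x86_64 host with 4 KiB base pages and 2 MiB huge
+// pages, MaxMessageSize() is 2097152 - 4096 = 2093056 bytes, leaving one base
+// page of room for the flipcall and channel headers within a single huge page.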
+
+// TODO(gvisor.dev/issue/6450): Once this is resolved:
+// * Update manual implementations and function signatures.
+// * Update RPC handlers and appropriate callers to handle errors correctly.
+// * Update manual implementations to get rid of buffer shifting.
+
+// UID represents a user ID.
+//
+// +marshal
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+//
+// +marshal
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes.
+func NoopMarshal([]byte) {}
+
+// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes.
+func NoopUnmarshal([]byte) {}
+
+// SizedString represents a string in memory. The marshalled string bytes are
+// preceded by a uint32 signifying the string length.
+type SizedString string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SizedString) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + len(*s)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SizedString) MarshalBytes(dst []byte) {
+ strLen := primitive.Uint32(len(*s))
+ strLen.MarshalUnsafe(dst)
+ dst = dst[strLen.SizeBytes():]
+ // Copy without any allocation.
+ copy(dst[:strLen], *s)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SizedString) UnmarshalBytes(src []byte) {
+ var strLen primitive.Uint32
+ strLen.UnmarshalUnsafe(src)
+ src = src[strLen.SizeBytes():]
+ // Take the hit, this leads to an allocation + memcpy. No way around it.
+ *s = SizedString(src[:strLen])
+}
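+
+// For example, the SizedString "abc" occupies 7 bytes on the wire: a uint32
+// length of 3 (in the host's native byte order, as produced by
+// primitive.Uint32.MarshalUnsafe) followed by the bytes 'a', 'b', 'c'.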
+
+// StringArray represents an array of SizedStrings in memory. The marshalled
+// array data is preceded by a uint32 signifying the array length.
+type StringArray []string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *StringArray) SizeBytes() int {
+ size := (*primitive.Uint32)(nil).SizeBytes()
+ for _, str := range *s {
+ sstr := SizedString(str)
+ size += sstr.SizeBytes()
+ }
+ return size
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *StringArray) MarshalBytes(dst []byte) {
+ arrLen := primitive.Uint32(len(*s))
+ arrLen.MarshalUnsafe(dst)
+ dst = dst[arrLen.SizeBytes():]
+ for _, str := range *s {
+ sstr := SizedString(str)
+ sstr.MarshalBytes(dst)
+ dst = dst[sstr.SizeBytes():]
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *StringArray) UnmarshalBytes(src []byte) {
+ var arrLen primitive.Uint32
+ arrLen.UnmarshalUnsafe(src)
+ src = src[arrLen.SizeBytes():]
+
+ if cap(*s) < int(arrLen) {
+ *s = make([]string, arrLen)
+ } else {
+ *s = (*s)[:arrLen]
+ }
+
+ for i := primitive.Uint32(0); i < arrLen; i++ {
+ var sstr SizedString
+ sstr.UnmarshalBytes(src)
+ src = src[sstr.SizeBytes():]
+ (*s)[i] = string(sstr)
+ }
+}
+
+// Inode represents an inode on the remote filesystem.
+//
+// +marshal slice:InodeSlice
+type Inode struct {
+ ControlFD FDID
+ _ uint32 // Need to make struct packed.
+ Stat linux.Statx
+}
+
+// MountReq represents a Mount request.
+type MountReq struct {
+ MountPath SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountReq) SizeBytes() int {
+ return m.MountPath.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountReq) MarshalBytes(dst []byte) {
+ m.MountPath.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountReq) UnmarshalBytes(src []byte) {
+ m.MountPath.UnmarshalBytes(src)
+}
+
+// MountResp represents a Mount response.
+type MountResp struct {
+ Root Inode
+ // MaxMessageSize is the maximum size of messages communicated between the
+ // client and server in bytes. This includes the communication header.
+ MaxMessageSize primitive.Uint32
+ // SupportedMs holds all the supported messages.
+ SupportedMs []MID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountResp) SizeBytes() int {
+ return m.Root.SizeBytes() +
+ m.MaxMessageSize.SizeBytes() +
+ (*primitive.Uint16)(nil).SizeBytes() +
+ (len(m.SupportedMs) * (*MID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountResp) MarshalBytes(dst []byte) {
+ m.Root.MarshalUnsafe(dst)
+ dst = dst[m.Root.SizeBytes():]
+ m.MaxMessageSize.MarshalUnsafe(dst)
+ dst = dst[m.MaxMessageSize.SizeBytes():]
+ numSupported := primitive.Uint16(len(m.SupportedMs))
+ numSupported.MarshalBytes(dst)
+ dst = dst[numSupported.SizeBytes():]
+ MarshalUnsafeMIDSlice(m.SupportedMs, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountResp) UnmarshalBytes(src []byte) {
+ m.Root.UnmarshalUnsafe(src)
+ src = src[m.Root.SizeBytes():]
+ m.MaxMessageSize.UnmarshalUnsafe(src)
+ src = src[m.MaxMessageSize.SizeBytes():]
+ var numSupported primitive.Uint16
+ numSupported.UnmarshalBytes(src)
+ src = src[numSupported.SizeBytes():]
+ m.SupportedMs = make([]MID, numSupported)
+ UnmarshalUnsafeMIDSlice(m.SupportedMs, src)
+}
+
+// ChannelResp is the response to the create channel request.
+//
+// +marshal
+type ChannelResp struct {
+ dataOffset int64
+ dataLength uint64
+}
+
+// ErrorResp is returned to represent an error while handling a request.
+//
+// +marshal
+type ErrorResp struct {
+ errno uint32
+}
+
+// StatReq requests the stat results for the specified FD.
+//
+// +marshal
+type StatReq struct {
+ FD FDID
+}
+
+// SetStatReq is used to set attributes on FDs.
+//
+// +marshal
+type SetStatReq struct {
+ FD FDID
+ _ uint32
+ Mask uint32
+ Mode uint32 // Only permissions part is settable.
+ UID UID
+ GID GID
+ Size uint64
+ Atime linux.Timespec
+ Mtime linux.Timespec
+}
+
+// SetStatResp is used to communicate SetStat results. It contains a mask
+// representing the failed changes. It also contains the errno of the failed
+// set attribute operation. If multiple operations failed then any of those
+// errnos can be returned.
+//
+// +marshal
+type SetStatResp struct {
+ FailureMask uint32
+ FailureErrNo uint32
+}
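+
+// For example, if a SetStat request asked to update the mode, UID and size
+// and only the ownership change failed, the server could return FailureMask
+// set to linux.STATX_UID and FailureErrNo set to the failing errno (e.g.
+// EPERM); the remaining attributes were still applied.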
+
+// WalkReq is used to request to walk multiple path components at once. This
+// is used for both Walk and WalkStat.
+type WalkReq struct {
+ DirFD FDID
+ Path StringArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkReq) SizeBytes() int {
+ return w.DirFD.SizeBytes() + w.Path.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkReq) MarshalBytes(dst []byte) {
+ w.DirFD.MarshalUnsafe(dst)
+ dst = dst[w.DirFD.SizeBytes():]
+ w.Path.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkReq) UnmarshalBytes(src []byte) {
+ w.DirFD.UnmarshalUnsafe(src)
+ src = src[w.DirFD.SizeBytes():]
+ w.Path.UnmarshalBytes(src)
+}
+
+// WalkStatus is used to indicate the reason for partial/unsuccessful
+// server-side Walk operations. Note that partial/unsuccessful walk operations
+// do not necessarily fail the RPC: the RPC still succeeds and carries a
+// failure hint which the client can use to infer server-side state.
+type WalkStatus = primitive.Uint8
+
+const (
+ // WalkSuccess indicates that all path components were successfully walked.
+ WalkSuccess WalkStatus = iota
+
+ // WalkComponentDoesNotExist indicates that the walk was prematurely
+	// terminated because an intermediate path component does not exist on the
+	// server. The results for all preceding existing path components are returned.
+ WalkComponentDoesNotExist
+
+ // WalkComponentSymlink indicates that the walk was prematurely
+ // terminated because an intermediate path component was a symlink. It is not
+	// safe to resolve symlinks remotely because the server is unaware of the
+	// client's mount points.
+ WalkComponentSymlink
+)
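+
+// For example, a Walk of ["a", "b", "c"] where "b" turns out to be a symlink
+// stops at "b" and returns WalkComponentSymlink along with the inodes walked
+// so far; the client can then resolve the symlink itself and continue the
+// walk from there.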
+
+// WalkResp is used to communicate the inodes walked by the server.
+type WalkResp struct {
+ Status WalkStatus
+ Inodes []Inode
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkResp) SizeBytes() int {
+ return w.Status.SizeBytes() +
+ (*primitive.Uint32)(nil).SizeBytes() + (len(w.Inodes) * (*Inode)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkResp) MarshalBytes(dst []byte) {
+ w.Status.MarshalUnsafe(dst)
+ dst = dst[w.Status.SizeBytes():]
+
+ numInodes := primitive.Uint32(len(w.Inodes))
+ numInodes.MarshalUnsafe(dst)
+ dst = dst[numInodes.SizeBytes():]
+
+ MarshalUnsafeInodeSlice(w.Inodes, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkResp) UnmarshalBytes(src []byte) {
+ w.Status.UnmarshalUnsafe(src)
+ src = src[w.Status.SizeBytes():]
+
+ var numInodes primitive.Uint32
+ numInodes.UnmarshalUnsafe(src)
+ src = src[numInodes.SizeBytes():]
+
+ if cap(w.Inodes) < int(numInodes) {
+ w.Inodes = make([]Inode, numInodes)
+ } else {
+ w.Inodes = w.Inodes[:numInodes]
+ }
+ UnmarshalUnsafeInodeSlice(w.Inodes, src)
+}
+
+// WalkStatResp is used to communicate stat results for WalkStat.
+type WalkStatResp struct {
+ Stats []linux.Statx
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkStatResp) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (len(w.Stats) * linux.SizeOfStatx)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkStatResp) MarshalBytes(dst []byte) {
+ numStats := primitive.Uint32(len(w.Stats))
+ numStats.MarshalUnsafe(dst)
+ dst = dst[numStats.SizeBytes():]
+
+ linux.MarshalUnsafeStatxSlice(w.Stats, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkStatResp) UnmarshalBytes(src []byte) {
+ var numStats primitive.Uint32
+ numStats.UnmarshalUnsafe(src)
+ src = src[numStats.SizeBytes():]
+
+ if cap(w.Stats) < int(numStats) {
+ w.Stats = make([]linux.Statx, numStats)
+ } else {
+ w.Stats = w.Stats[:numStats]
+ }
+ linux.UnmarshalUnsafeStatxSlice(w.Stats, src)
+}
+
+// OpenAtReq is used to open existing FDs with the specified flags.
+//
+// +marshal
+type OpenAtReq struct {
+ FD FDID
+ Flags uint32
+}
+
+// OpenAtResp is used to communicate the newly created FD.
+//
+// +marshal
+type OpenAtResp struct {
+ NewFD FDID
+}
+
+// +marshal
+type createCommon struct {
+ DirFD FDID
+ Mode linux.FileMode
+ _ uint16 // Need to make struct packed.
+ UID UID
+ GID GID
+}
+
+// OpenCreateAtReq is used to make OpenCreateAt requests.
+type OpenCreateAtReq struct {
+ createCommon
+ Name SizedString
+ Flags primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (o *OpenCreateAtReq) SizeBytes() int {
+ return o.createCommon.SizeBytes() + o.Name.SizeBytes() + o.Flags.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (o *OpenCreateAtReq) MarshalBytes(dst []byte) {
+ o.createCommon.MarshalUnsafe(dst)
+ dst = dst[o.createCommon.SizeBytes():]
+ o.Name.MarshalBytes(dst)
+ dst = dst[o.Name.SizeBytes():]
+ o.Flags.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) {
+ o.createCommon.UnmarshalUnsafe(src)
+ src = src[o.createCommon.SizeBytes():]
+ o.Name.UnmarshalBytes(src)
+ src = src[o.Name.SizeBytes():]
+ o.Flags.UnmarshalUnsafe(src)
+}
+
+// OpenCreateAtResp is used to communicate successful OpenCreateAt results.
+//
+// +marshal
+type OpenCreateAtResp struct {
+ Child Inode
+ NewFD FDID
+ _ uint32 // Needed to make the struct packed.
+}
+
+// FdArray is a utility struct which implements a marshallable type for
+// communicating an array of FDIDs. In memory, the array data is preceded by a
+// uint32 denoting the array length.
+type FdArray []FDID
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (f *FdArray) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (len(*f) * (*FDID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (f *FdArray) MarshalBytes(dst []byte) {
+ arrLen := primitive.Uint32(len(*f))
+ arrLen.MarshalUnsafe(dst)
+ dst = dst[arrLen.SizeBytes():]
+ MarshalUnsafeFDIDSlice(*f, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (f *FdArray) UnmarshalBytes(src []byte) {
+ var arrLen primitive.Uint32
+ arrLen.UnmarshalUnsafe(src)
+ src = src[arrLen.SizeBytes():]
+ if cap(*f) < int(arrLen) {
+ *f = make(FdArray, arrLen)
+ } else {
+ *f = (*f)[:arrLen]
+ }
+ UnmarshalUnsafeFDIDSlice(*f, src)
+}
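+
+// For illustration (assuming FDID is a 4-byte identifier): FdArray{3, 5}
+// marshals to 12 bytes, a uint32 length of 2 followed by the two FDIDs.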
+
+// CloseReq is used to close(2) FDs.
+type CloseReq struct {
+ FDs FdArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (c *CloseReq) SizeBytes() int {
+ return c.FDs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (c *CloseReq) MarshalBytes(dst []byte) {
+ c.FDs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (c *CloseReq) UnmarshalBytes(src []byte) {
+ c.FDs.UnmarshalBytes(src)
+}
+
+// FsyncReq is used to fsync(2) FDs.
+type FsyncReq struct {
+ FDs FdArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (f *FsyncReq) SizeBytes() int {
+ return f.FDs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (f *FsyncReq) MarshalBytes(dst []byte) {
+ f.FDs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (f *FsyncReq) UnmarshalBytes(src []byte) {
+ f.FDs.UnmarshalBytes(src)
+}
+
+// PReadReq is used to pread(2) on an FD.
+//
+// +marshal
+type PReadReq struct {
+ Offset uint64
+ FD FDID
+ Count uint32
+}
+
+// PReadResp is used to return the result of pread(2).
+type PReadResp struct {
+ NumBytes primitive.Uint32
+ Buf []byte
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *PReadResp) SizeBytes() int {
+ return r.NumBytes.SizeBytes() + int(r.NumBytes)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *PReadResp) MarshalBytes(dst []byte) {
+ r.NumBytes.MarshalUnsafe(dst)
+ dst = dst[r.NumBytes.SizeBytes():]
+ copy(dst[:r.NumBytes], r.Buf[:r.NumBytes])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *PReadResp) UnmarshalBytes(src []byte) {
+ r.NumBytes.UnmarshalUnsafe(src)
+ src = src[r.NumBytes.SizeBytes():]
+
+ // The client is expected to have already allocated r.Buf. Ideally r.Buf
+ // points directly to usermem, so copy into it directly.
+ copy(r.Buf[:r.NumBytes], src[:r.NumBytes])
+}
+
+// PWriteReq is used to pwrite(2) on an FD.
+type PWriteReq struct {
+ Offset primitive.Uint64
+ FD FDID
+ NumBytes primitive.Uint32
+ Buf []byte
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *PWriteReq) SizeBytes() int {
+ return w.Offset.SizeBytes() + w.FD.SizeBytes() + w.NumBytes.SizeBytes() + int(w.NumBytes)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *PWriteReq) MarshalBytes(dst []byte) {
+ w.Offset.MarshalUnsafe(dst)
+ dst = dst[w.Offset.SizeBytes():]
+ w.FD.MarshalUnsafe(dst)
+ dst = dst[w.FD.SizeBytes():]
+ w.NumBytes.MarshalUnsafe(dst)
+ dst = dst[w.NumBytes.SizeBytes():]
+ copy(dst[:w.NumBytes], w.Buf[:w.NumBytes])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *PWriteReq) UnmarshalBytes(src []byte) {
+ w.Offset.UnmarshalUnsafe(src)
+ src = src[w.Offset.SizeBytes():]
+ w.FD.UnmarshalUnsafe(src)
+ src = src[w.FD.SizeBytes():]
+ w.NumBytes.UnmarshalUnsafe(src)
+ src = src[w.NumBytes.SizeBytes():]
+
+ // This is an optimization. Assuming that the server is making this call, it
+ // is safe to just point to src rather than allocating and copying.
+ w.Buf = src[:w.NumBytes]
+}
+
+// PWriteResp is used to return the result of pwrite(2).
+//
+// +marshal
+type PWriteResp struct {
+ Count uint64
+}
+
+// MkdirAtReq is used to make MkdirAt requests.
+type MkdirAtReq struct {
+ createCommon
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MkdirAtReq) SizeBytes() int {
+ return m.createCommon.SizeBytes() + m.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MkdirAtReq) MarshalBytes(dst []byte) {
+ m.createCommon.MarshalUnsafe(dst)
+ dst = dst[m.createCommon.SizeBytes():]
+ m.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MkdirAtReq) UnmarshalBytes(src []byte) {
+ m.createCommon.UnmarshalUnsafe(src)
+ src = src[m.createCommon.SizeBytes():]
+ m.Name.UnmarshalBytes(src)
+}
+
+// MkdirAtResp is the response to a successful MkdirAt request.
+//
+// +marshal
+type MkdirAtResp struct {
+ ChildDir Inode
+}
+
+// MknodAtReq is used to make MknodAt requests.
+type MknodAtReq struct {
+ createCommon
+ Name SizedString
+ Minor primitive.Uint32
+ Major primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MknodAtReq) SizeBytes() int {
+ return m.createCommon.SizeBytes() + m.Name.SizeBytes() + m.Minor.SizeBytes() + m.Major.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MknodAtReq) MarshalBytes(dst []byte) {
+ m.createCommon.MarshalUnsafe(dst)
+ dst = dst[m.createCommon.SizeBytes():]
+ m.Name.MarshalBytes(dst)
+ dst = dst[m.Name.SizeBytes():]
+ m.Minor.MarshalUnsafe(dst)
+ dst = dst[m.Minor.SizeBytes():]
+ m.Major.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MknodAtReq) UnmarshalBytes(src []byte) {
+ m.createCommon.UnmarshalUnsafe(src)
+ src = src[m.createCommon.SizeBytes():]
+ m.Name.UnmarshalBytes(src)
+ src = src[m.Name.SizeBytes():]
+ m.Minor.UnmarshalUnsafe(src)
+ src = src[m.Minor.SizeBytes():]
+ m.Major.UnmarshalUnsafe(src)
+}
+
+// MknodAtResp is the response to a successful MknodAt request.
+//
+// +marshal
+type MknodAtResp struct {
+ Child Inode
+}
+
+// SymlinkAtReq is used to make SymlinkAt requests.
+type SymlinkAtReq struct {
+ DirFD FDID
+ Name SizedString
+ Target SizedString
+ UID UID
+ GID GID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SymlinkAtReq) SizeBytes() int {
+ return s.DirFD.SizeBytes() + s.Name.SizeBytes() + s.Target.SizeBytes() + s.UID.SizeBytes() + s.GID.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SymlinkAtReq) MarshalBytes(dst []byte) {
+ s.DirFD.MarshalUnsafe(dst)
+ dst = dst[s.DirFD.SizeBytes():]
+ s.Name.MarshalBytes(dst)
+ dst = dst[s.Name.SizeBytes():]
+ s.Target.MarshalBytes(dst)
+ dst = dst[s.Target.SizeBytes():]
+ s.UID.MarshalUnsafe(dst)
+ dst = dst[s.UID.SizeBytes():]
+ s.GID.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SymlinkAtReq) UnmarshalBytes(src []byte) {
+ s.DirFD.UnmarshalUnsafe(src)
+ src = src[s.DirFD.SizeBytes():]
+ s.Name.UnmarshalBytes(src)
+ src = src[s.Name.SizeBytes():]
+ s.Target.UnmarshalBytes(src)
+ src = src[s.Target.SizeBytes():]
+ s.UID.UnmarshalUnsafe(src)
+ src = src[s.UID.SizeBytes():]
+ s.GID.UnmarshalUnsafe(src)
+}
+
+// SymlinkAtResp is the response to a successful SymlinkAt request.
+//
+// +marshal
+type SymlinkAtResp struct {
+ Symlink Inode
+}
+
+// LinkAtReq is used to make LinkAt requests.
+type LinkAtReq struct {
+ DirFD FDID
+ Target FDID
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (l *LinkAtReq) SizeBytes() int {
+ return l.DirFD.SizeBytes() + l.Target.SizeBytes() + l.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (l *LinkAtReq) MarshalBytes(dst []byte) {
+ l.DirFD.MarshalUnsafe(dst)
+ dst = dst[l.DirFD.SizeBytes():]
+ l.Target.MarshalUnsafe(dst)
+ dst = dst[l.Target.SizeBytes():]
+ l.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (l *LinkAtReq) UnmarshalBytes(src []byte) {
+ l.DirFD.UnmarshalUnsafe(src)
+ src = src[l.DirFD.SizeBytes():]
+ l.Target.UnmarshalUnsafe(src)
+ src = src[l.Target.SizeBytes():]
+ l.Name.UnmarshalBytes(src)
+}
+
+// LinkAtResp is used to respond to a successful LinkAt request.
+//
+// +marshal
+type LinkAtResp struct {
+ Link Inode
+}
+
+// FStatFSReq is used to request StatFS results for the specified FD.
+//
+// +marshal
+type FStatFSReq struct {
+ FD FDID
+}
+
+// StatFS is the response to a successful FStatFS request.
+//
+// +marshal
+type StatFS struct {
+ Type uint64
+ BlockSize int64
+ Blocks uint64
+ BlocksFree uint64
+ BlocksAvailable uint64
+ Files uint64
+ FilesFree uint64
+ NameLength uint64
+}
+
+// FAllocateReq is used to fallocate(2) an FD. It has no response.
+//
+// +marshal
+type FAllocateReq struct {
+ FD FDID
+ _ uint32
+ Mode uint64
+ Offset uint64
+ Length uint64
+}
+
+// ReadLinkAtReq is used to readlinkat(2) at the specified FD.
+//
+// +marshal
+type ReadLinkAtReq struct {
+ FD FDID
+}
+
+// ReadLinkAtResp is used to communicate ReadLinkAt results.
+type ReadLinkAtResp struct {
+ Target SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *ReadLinkAtResp) SizeBytes() int {
+ return r.Target.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *ReadLinkAtResp) MarshalBytes(dst []byte) {
+ r.Target.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) {
+ r.Target.UnmarshalBytes(src)
+}
+
+// FlushReq is used to make Flush requests.
+//
+// +marshal
+type FlushReq struct {
+ FD FDID
+}
+
+// ConnectReq is used to make a Connect request.
+//
+// +marshal
+type ConnectReq struct {
+ FD FDID
+ // SockType is used to specify the socket type to connect to. As a special
+ // case, SockType = 0 means that the socket type does not matter and the
+ // requester will accept any socket type.
+ SockType uint32
+}
+
+// UnlinkAtReq is used to make UnlinkAt requests.
+type UnlinkAtReq struct {
+ DirFD FDID
+ Name SizedString
+ Flags primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (u *UnlinkAtReq) SizeBytes() int {
+ return u.DirFD.SizeBytes() + u.Name.SizeBytes() + u.Flags.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (u *UnlinkAtReq) MarshalBytes(dst []byte) {
+ u.DirFD.MarshalUnsafe(dst)
+ dst = dst[u.DirFD.SizeBytes():]
+ u.Name.MarshalBytes(dst)
+ dst = dst[u.Name.SizeBytes():]
+ u.Flags.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (u *UnlinkAtReq) UnmarshalBytes(src []byte) {
+ u.DirFD.UnmarshalUnsafe(src)
+ src = src[u.DirFD.SizeBytes():]
+ u.Name.UnmarshalBytes(src)
+ src = src[u.Name.SizeBytes():]
+ u.Flags.UnmarshalUnsafe(src)
+}
+
+// RenameAtReq is used to make Rename requests. Note that, unlike renameat(2),
+// the request takes the to-be-renamed file's FD instead of oldDir and oldName.
+type RenameAtReq struct {
+ Renamed FDID
+ NewDir FDID
+ NewName SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *RenameAtReq) SizeBytes() int {
+ return r.Renamed.SizeBytes() + r.NewDir.SizeBytes() + r.NewName.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *RenameAtReq) MarshalBytes(dst []byte) {
+ r.Renamed.MarshalUnsafe(dst)
+ dst = dst[r.Renamed.SizeBytes():]
+ r.NewDir.MarshalUnsafe(dst)
+ dst = dst[r.NewDir.SizeBytes():]
+ r.NewName.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *RenameAtReq) UnmarshalBytes(src []byte) {
+ r.Renamed.UnmarshalUnsafe(src)
+ src = src[r.Renamed.SizeBytes():]
+ r.NewDir.UnmarshalUnsafe(src)
+ src = src[r.NewDir.SizeBytes():]
+ r.NewName.UnmarshalBytes(src)
+}
+
+// Getdents64Req is used to make Getdents64 requests.
+//
+// +marshal
+type Getdents64Req struct {
+ DirFD FDID
+ // Count is the number of bytes to read. A negative value of Count is used to
+ // indicate that the implementation must lseek(0, SEEK_SET) before calling
+ // getdents64(2). Implementations must use the absolute value of Count to
+ // determine the number of bytes to read.
+ Count int32
+}
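+
+// For illustration only, a handler honoring the Count convention above could
+// look roughly like the sketch below; dirFD stands for the host directory FD
+// and the real handler (defined elsewhere) may differ:
+//
+//	count := req.Count
+//	if count < 0 {
+//		count = -count
+//		if _, err := unix.Seek(dirFD, 0, unix.SEEK_SET); err != nil {
+//			return err
+//		}
+//	}
+//	buf := make([]byte, count)
+//	n, err := unix.Getdents(dirFD, buf)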
+
+// Dirent64 is analogous to struct linux_dirent64.
+type Dirent64 struct {
+ Ino primitive.Uint64
+ DevMinor primitive.Uint32
+ DevMajor primitive.Uint32
+ Off primitive.Uint64
+ Type primitive.Uint8
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (d *Dirent64) SizeBytes() int {
+ return d.Ino.SizeBytes() + d.DevMinor.SizeBytes() + d.DevMajor.SizeBytes() + d.Off.SizeBytes() + d.Type.SizeBytes() + d.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (d *Dirent64) MarshalBytes(dst []byte) {
+ d.Ino.MarshalUnsafe(dst)
+ dst = dst[d.Ino.SizeBytes():]
+ d.DevMinor.MarshalUnsafe(dst)
+ dst = dst[d.DevMinor.SizeBytes():]
+ d.DevMajor.MarshalUnsafe(dst)
+ dst = dst[d.DevMajor.SizeBytes():]
+ d.Off.MarshalUnsafe(dst)
+ dst = dst[d.Off.SizeBytes():]
+ d.Type.MarshalUnsafe(dst)
+ dst = dst[d.Type.SizeBytes():]
+ d.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (d *Dirent64) UnmarshalBytes(src []byte) {
+ d.Ino.UnmarshalUnsafe(src)
+ src = src[d.Ino.SizeBytes():]
+ d.DevMinor.UnmarshalUnsafe(src)
+ src = src[d.DevMinor.SizeBytes():]
+ d.DevMajor.UnmarshalUnsafe(src)
+ src = src[d.DevMajor.SizeBytes():]
+ d.Off.UnmarshalUnsafe(src)
+ src = src[d.Off.SizeBytes():]
+ d.Type.UnmarshalUnsafe(src)
+ src = src[d.Type.SizeBytes():]
+ d.Name.UnmarshalBytes(src)
+}
+
+// Getdents64Resp is used to communicate getdents64 results.
+type Getdents64Resp struct {
+ Dirents []Dirent64
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *Getdents64Resp) SizeBytes() int {
+ ret := (*primitive.Uint32)(nil).SizeBytes()
+ for i := range g.Dirents {
+ ret += g.Dirents[i].SizeBytes()
+ }
+ return ret
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *Getdents64Resp) MarshalBytes(dst []byte) {
+ numDirents := primitive.Uint32(len(g.Dirents))
+ numDirents.MarshalUnsafe(dst)
+ dst = dst[numDirents.SizeBytes():]
+ for i := range g.Dirents {
+ g.Dirents[i].MarshalBytes(dst)
+ dst = dst[g.Dirents[i].SizeBytes():]
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *Getdents64Resp) UnmarshalBytes(src []byte) {
+ var numDirents primitive.Uint32
+ numDirents.UnmarshalUnsafe(src)
+ if cap(g.Dirents) < int(numDirents) {
+ g.Dirents = make([]Dirent64, numDirents)
+ } else {
+ g.Dirents = g.Dirents[:numDirents]
+ }
+
+ src = src[numDirents.SizeBytes():]
+ for i := range g.Dirents {
+ g.Dirents[i].UnmarshalBytes(src)
+ src = src[g.Dirents[i].SizeBytes():]
+ }
+}
+
+// FGetXattrReq is used to make FGetXattr requests. The response to this is
+// just a SizedString containing the xattr value.
+type FGetXattrReq struct {
+ FD FDID
+ BufSize primitive.Uint32
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *FGetXattrReq) SizeBytes() int {
+ return g.FD.SizeBytes() + g.BufSize.SizeBytes() + g.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *FGetXattrReq) MarshalBytes(dst []byte) {
+ g.FD.MarshalUnsafe(dst)
+ dst = dst[g.FD.SizeBytes():]
+ g.BufSize.MarshalUnsafe(dst)
+ dst = dst[g.BufSize.SizeBytes():]
+ g.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *FGetXattrReq) UnmarshalBytes(src []byte) {
+ g.FD.UnmarshalUnsafe(src)
+ src = src[g.FD.SizeBytes():]
+ g.BufSize.UnmarshalUnsafe(src)
+ src = src[g.BufSize.SizeBytes():]
+ g.Name.UnmarshalBytes(src)
+}
+
+// FGetXattrResp is used to respond to an FGetXattr request.
+type FGetXattrResp struct {
+ Value SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *FGetXattrResp) SizeBytes() int {
+ return g.Value.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *FGetXattrResp) MarshalBytes(dst []byte) {
+ g.Value.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *FGetXattrResp) UnmarshalBytes(src []byte) {
+ g.Value.UnmarshalBytes(src)
+}
+
+// FSetXattrReq is used to make FSetXattr requests. It has no response.
+type FSetXattrReq struct {
+ FD FDID
+ Flags primitive.Uint32
+ Name SizedString
+ Value SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *FSetXattrReq) SizeBytes() int {
+ return s.FD.SizeBytes() + s.Flags.SizeBytes() + s.Name.SizeBytes() + s.Value.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *FSetXattrReq) MarshalBytes(dst []byte) {
+ s.FD.MarshalUnsafe(dst)
+ dst = dst[s.FD.SizeBytes():]
+ s.Flags.MarshalUnsafe(dst)
+ dst = dst[s.Flags.SizeBytes():]
+ s.Name.MarshalBytes(dst)
+ dst = dst[s.Name.SizeBytes():]
+ s.Value.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *FSetXattrReq) UnmarshalBytes(src []byte) {
+ s.FD.UnmarshalUnsafe(src)
+ src = src[s.FD.SizeBytes():]
+ s.Flags.UnmarshalUnsafe(src)
+ src = src[s.Flags.SizeBytes():]
+ s.Name.UnmarshalBytes(src)
+ src = src[s.Name.SizeBytes():]
+ s.Value.UnmarshalBytes(src)
+}
+
+// FRemoveXattrReq is used to make FRemoveXattr requests. It has no response.
+type FRemoveXattrReq struct {
+ FD FDID
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *FRemoveXattrReq) SizeBytes() int {
+ return r.FD.SizeBytes() + r.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *FRemoveXattrReq) MarshalBytes(dst []byte) {
+ r.FD.MarshalUnsafe(dst)
+ dst = dst[r.FD.SizeBytes():]
+ r.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) {
+ r.FD.UnmarshalUnsafe(src)
+ src = src[r.FD.SizeBytes():]
+ r.Name.UnmarshalBytes(src)
+}
+
+// FListXattrReq is used to make FListXattr requests.
+//
+// +marshal
+type FListXattrReq struct {
+ FD FDID
+ _ uint32
+ Size uint64
+}
+
+// FListXattrResp is used to respond to FListXattr requests.
+type FListXattrResp struct {
+ Xattrs StringArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (l *FListXattrResp) SizeBytes() int {
+ return l.Xattrs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (l *FListXattrResp) MarshalBytes(dst []byte) {
+ l.Xattrs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (l *FListXattrResp) UnmarshalBytes(src []byte) {
+ l.Xattrs.UnmarshalBytes(src)
+}
diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go
new file mode 100644
index 000000000..3868dfa08
--- /dev/null
+++ b/pkg/lisafs/sample_message.go
@@ -0,0 +1,110 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// MsgSimple is a sample packed struct which can be used to test message passing.
+//
+// +marshal slice:Msg1Slice
+type MsgSimple struct {
+ A uint16
+ B uint16
+ C uint32
+ D uint64
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgSimple) Randomize() {
+ m.A = uint16(rand.Uint32())
+ m.B = uint16(rand.Uint32())
+ m.C = rand.Uint32()
+ m.D = rand.Uint64()
+}
+
+// MsgDynamic is a sample dynamic struct which can be used to test message passing.
+//
+// +marshal dynamic
+type MsgDynamic struct {
+ N primitive.Uint32
+ Arr []MsgSimple
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MsgDynamic) SizeBytes() int {
+ return m.N.SizeBytes() +
+ (int(m.N) * (*MsgSimple)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MsgDynamic) MarshalBytes(dst []byte) {
+ m.N.MarshalUnsafe(dst)
+ dst = dst[m.N.SizeBytes():]
+ MarshalUnsafeMsg1Slice(m.Arr, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MsgDynamic) UnmarshalBytes(src []byte) {
+ m.N.UnmarshalUnsafe(src)
+ src = src[m.N.SizeBytes():]
+ m.Arr = make([]MsgSimple, m.N)
+ UnmarshalUnsafeMsg1Slice(m.Arr, src)
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgDynamic) Randomize(arrLen int) {
+ m.N = primitive.Uint32(arrLen)
+ m.Arr = make([]MsgSimple, arrLen)
+ for i := 0; i < arrLen; i++ {
+ m.Arr[i].Randomize()
+ }
+}
+
+// P9Version mimics p9.TVersion and p9.Rversion.
+//
+// +marshal dynamic
+type P9Version struct {
+ MSize primitive.Uint32
+ Version string
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (v *P9Version) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (v *P9Version) MarshalBytes(dst []byte) {
+ v.MSize.MarshalUnsafe(dst)
+ dst = dst[v.MSize.SizeBytes():]
+ versionLen := primitive.Uint16(len(v.Version))
+ versionLen.MarshalUnsafe(dst)
+ dst = dst[versionLen.SizeBytes():]
+ copy(dst, v.Version)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (v *P9Version) UnmarshalBytes(src []byte) {
+ v.MSize.UnmarshalUnsafe(src)
+ src = src[v.MSize.SizeBytes():]
+ var versionLen primitive.Uint16
+ versionLen.UnmarshalUnsafe(src)
+ src = src[versionLen.SizeBytes():]
+ v.Version = string(src[:versionLen])
+}
diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go
new file mode 100644
index 000000000..7515355ec
--- /dev/null
+++ b/pkg/lisafs/server.go
@@ -0,0 +1,113 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Server serves a filesystem tree. Multiple connections on different mount
+// points can be started on a server. The server provides utilities to safely
+// modify the filesystem tree across its connections (mount points). Note that
+// it does not support synchronizing filesystem tree mutations across other
+// servers serving the same filesystem subtree. Server also manages the
+// lifecycle of all connections.
+type Server struct {
+ // connWg counts the number of active connections being tracked.
+ connWg sync.WaitGroup
+
+ // RenameMu synchronizes rename operations within this filesystem tree.
+ RenameMu sync.RWMutex
+
+ // handlers is a list of RPC handlers which can be indexed by the handler's
+ // corresponding MID.
+ handlers []RPCHandler
+
+ // mountPoints keeps track of all the mount points this server serves.
+ mpMu sync.RWMutex
+ mountPoints []*ControlFD
+
+ // impl is the server implementation which embeds this server.
+ impl ServerImpl
+}
+
+// Init must be called before the first use of the server.
+func (s *Server) Init(impl ServerImpl) {
+ s.impl = impl
+ s.handlers = handlers[:]
+}
+
+// InitTestOnly is the same as Init except that it allows the caller to swap
+// out the underlying handlers with custom ones. This is for testing only.
+func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) {
+ s.impl = impl
+ s.handlers = handlers
+}
+
+// WithRenameReadLock invokes fn with the server's rename mutex locked for
+// reading. This ensures that no rename operations occur concurrently.
+func (s *Server) WithRenameReadLock(fn func() error) error {
+ s.RenameMu.RLock()
+ err := fn()
+ s.RenameMu.RUnlock()
+ return err
+}
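+
+// Illustrative usage (walkPath is a hypothetical callback):
+//
+//	err := s.WithRenameReadLock(func() error {
+//		// No rename can mutate the tree while walkPath runs.
+//		return walkPath()
+//	})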
+
+// StartConnection starts the connection on a separate goroutine and tracks it.
+func (s *Server) StartConnection(c *Connection) {
+ s.connWg.Add(1)
+ go func() {
+ c.Run()
+ s.connWg.Done()
+ }()
+}
+
+// Wait waits for all connections started via StartConnection() to terminate.
+func (s *Server) Wait() {
+ s.connWg.Wait()
+}
+
+func (s *Server) addMountPoint(root *ControlFD) {
+ s.mpMu.Lock()
+ defer s.mpMu.Unlock()
+ s.mountPoints = append(s.mountPoints, root)
+}
+
+func (s *Server) forEachMountPoint(fn func(root *ControlFD)) {
+ s.mpMu.RLock()
+ defer s.mpMu.RUnlock()
+ for _, mp := range s.mountPoints {
+ fn(mp)
+ }
+}
+
+// ServerImpl contains the implementation details for a Server.
+// Implementations of ServerImpl should contain their associated Server by
+// value as their first field.
+type ServerImpl interface {
+ // Mount is called when a Mount RPC is made. It mounts the connection at
+ // mountPath.
+ //
+ // Precondition: mountPath == path.Clean(mountPath).
+ Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error)
+
+ // SupportedMessages returns a list of messages that the server
+ // implementation supports.
+ SupportedMessages() []MID
+
+ // MaxMessageSize is the maximum payload length (in bytes) that can be sent
+ // to this server implementation.
+ MaxMessageSize() uint32
+}
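+
+// A minimal ServerImpl sketch for illustration; localFS, its rootPath field
+// and the remaining ServerImpl methods are hypothetical and elided here:
+//
+//	type localFS struct {
+//		lisafs.Server // must be the first field, by value
+//		rootPath string
+//	}
+//
+//	func newLocalFS(rootPath string) *localFS {
+//		fs := &localFS{rootPath: rootPath}
+//		fs.Init(fs)
+//		return fs
+//	}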
diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go
new file mode 100644
index 000000000..88210242f
--- /dev/null
+++ b/pkg/lisafs/sock.go
@@ -0,0 +1,208 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+var (
+ sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
+)
+
+// sockHeader is the header present in front of each message received on a UDS.
+//
+// +marshal
+type sockHeader struct {
+ payloadLen uint32
+ message MID
+ _ uint16 // Needed to make the struct packed.
+}
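+
+// Assuming MID is a 16-bit message identifier (which the explicit padding
+// field above suggests), every message on the socket is an 8-byte header
+// followed by the payload:
+//
+//	[0:4] payloadLen (uint32)
+//	[4:6] message    (MID)
+//	[6:8] padding
+//	[8:]  payload    (payloadLen bytes)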
+
+// sockCommunicator implements Communicator. This is not thread safe.
+type sockCommunicator struct {
+ fdTracker
+ sock *unet.Socket
+ buf []byte
+}
+
+var _ Communicator = (*sockCommunicator)(nil)
+
+func newSockComm(sock *unet.Socket) *sockCommunicator {
+ return &sockCommunicator{
+ sock: sock,
+ buf: make([]byte, sockHeaderLen),
+ }
+}
+
+func (s *sockCommunicator) FD() int {
+ return s.sock.FD()
+}
+
+func (s *sockCommunicator) destroy() {
+ s.sock.Close()
+}
+
+func (s *sockCommunicator) shutdown() {
+ if err := s.sock.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err)
+ }
+}
+
+func (s *sockCommunicator) resizeBuf(size uint32) {
+ if cap(s.buf) < int(size) {
+ s.buf = s.buf[:cap(s.buf)]
+ s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...)
+ } else {
+ s.buf = s.buf[:size]
+ }
+}
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (s *sockCommunicator) PayloadBuf(size uint32) []byte {
+ s.resizeBuf(sockHeaderLen + size)
+ return s.buf[sockHeaderLen : sockHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil {
+ return 0, 0, err
+ }
+
+ return s.rcvMsg(wantFDs)
+}
+
+// sndPrepopulatedMsg assumes that the payload has already been written to
+// s.buf[sockHeaderLen:] (e.g. via PayloadBuf) and is `payloadLen` bytes long.
+func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error {
+ header := sockHeader{payloadLen: payloadLen, message: m}
+ header.MarshalUnsafe(s.buf)
+ dataLen := sockHeaderLen + payloadLen
+ return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds)
+}
+
+// writeTo writes the passed iovec to the UDS and donates any passed FDs.
+func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error {
+ w := sock.Writer(true)
+ if len(fds) > 0 {
+ w.PackFDs(fds...)
+ }
+
+ fdsUnpacked := false
+ for n := 0; n < dataLen; {
+ cur, err := w.WriteVec(iovec)
+ if err != nil {
+ return err
+ }
+ n += cur
+
+ // Fast common path.
+ if n >= dataLen {
+ break
+ }
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(iovec[0]) <= cur-consumed {
+ consumed += len(iovec[0])
+ iovec = iovec[1:]
+ } else {
+ iovec[0] = iovec[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && !fdsUnpacked {
+ // Don't resend any control message.
+ fdsUnpacked = true
+ w.UnpackFDs()
+ }
+ }
+ return nil
+}
+
+// rcvMsg reads the message header and payload from the UDS. It also populates
+// fds with any donated FDs.
+func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) {
+ fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs)
+ if err != nil {
+ return 0, 0, err
+ }
+ for _, fd := range fds {
+ s.TrackFD(fd)
+ }
+
+ var header sockHeader
+ header.UnmarshalUnsafe(s.buf)
+
+ // No payload? We are done.
+ if header.payloadLen == 0 {
+ return header.message, 0, nil
+ }
+
+ if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil {
+ return 0, 0, err
+ }
+
+ return header.message, header.payloadLen, nil
+}
+
+// readFrom fills the passed buffer with data from the socket. It also returns
+// any donated FDs.
+func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) {
+ r := sock.Reader(true)
+ r.EnableFDs(int(wantFDs))
+
+ var (
+ fds []int
+ fdInit bool
+ )
+ n := len(buf)
+ for got := 0; got < n; {
+ cur, err := r.ReadVec([][]byte{buf[got:]})
+
+ // Ignore EOF if cur > 0.
+ if err != nil && (err != io.EOF || cur == 0) {
+ r.CloseFDs()
+ return nil, err
+ }
+
+ if !fdInit && cur > 0 {
+ fds, err = r.ExtractFDs()
+ if err != nil {
+ return nil, err
+ }
+
+ fdInit = true
+ r.EnableFDs(0)
+ }
+
+ got += cur
+ }
+ return fds, nil
+}
+
+func closeFDs(fds []int) {
+ for _, fd := range fds {
+ if fd >= 0 {
+ unix.Close(fd)
+ }
+ }
+}
diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go
new file mode 100644
index 000000000..387f4b7a8
--- /dev/null
+++ b/pkg/lisafs/sock_test.go
@@ -0,0 +1,217 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "bytes"
+ "math/rand"
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) {
+ sock1, sock2, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer sock1.Close()
+ defer sock2.Close()
+
+ var testWg sync.WaitGroup
+ testWg.Add(2)
+
+ go func() {
+ fun1(newSockComm(sock1))
+ testWg.Done()
+ }()
+
+ go func() {
+ fun2(newSockComm(sock2))
+ testWg.Done()
+ }()
+
+ testWg.Wait()
+}
+
+func TestReadWrite(t *testing.T) {
+ // Create random data to send.
+ n := 10000
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ // Scatter that data into two parts using Iovecs while sending.
+ mid := n / 2
+ if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ if _, err := readFrom(comm.sock, gotData, 0); err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+
+ // Make sure we got the right data.
+ if res := bytes.Compare(data, gotData); res != 0 {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+ })
+}
+
+func TestFDDonation(t *testing.T) {
+ n := 10
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ // Try donating FDs to these files.
+ path1 := "/dev/null"
+ path2 := "/dev"
+ path3 := "/dev/random"
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path1, err)
+ }
+ defer unix.Close(devNullFD)
+ devFD, err := unix.Open(path2, unix.O_RDONLY, 0)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path2, err)
+ }
+ defer unix.Close(devFD)
+ devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path3, err)
+ }
+ defer unix.Close(devRandomFD)
+ if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ fds, err := readFrom(comm.sock, gotData, 3)
+ if err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+ defer closeFDs(fds[:])
+
+ if res := bytes.Compare(data, gotData); res != 0 {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+
+ if len(fds) != 3 {
+ t.Fatalf("wanted 3 FD, got %d", len(fds))
+ }
+
+ // Check that the FDs actually point to the correct file.
+ compareFDWithFile(t, fds[0], path1)
+ compareFDWithFile(t, fds[1], path2)
+ compareFDWithFile(t, fds[2], path3)
+ })
+}
+
+func compareFDWithFile(t *testing.T, fd int, path string) {
+ var want unix.Stat_t
+ if err := unix.Stat(path, &want); err != nil {
+ t.Fatalf("stat(%s) failed: %v", path, err)
+ }
+
+ var got unix.Stat_t
+ if err := unix.Fstat(fd, &got); err != nil {
+ t.Fatalf("fstat on donated FD failed: %v", err)
+ }
+
+ if got.Ino != want.Ino || got.Dev != want.Dev {
+ t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got)
+ }
+}
+
+func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error {
+ var payloadLen uint32
+ if msg != nil {
+ payloadLen = uint32(msg.SizeBytes())
+ msg.MarshalUnsafe(comm.PayloadBuf(payloadLen))
+ }
+ return comm.sndPrepopulatedMsg(m, payloadLen, nil)
+}
+
+func TestSndRcvMessage(t *testing.T) {
+ req := &MsgSimple{}
+ req.Randomize()
+ reqM := MID(1)
+
+ // Create a massive random response.
+ var resp MsgDynamic
+ resp.Randomize(100)
+ respM := MID(2)
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, req); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, &resp)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, req)
+ if err := testSndMsg(comm, respM, &resp); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func TestSndRcvMessageNoPayload(t *testing.T) {
+ reqM := MID(1)
+ respM := MID(2)
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, nil)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, nil)
+ if err := testSndMsg(comm, respM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) {
+ gotM, payloadLen, err := comm.rcvMsg(0)
+ if err != nil {
+ t.Fatalf("readMessageFrom failed: %v", err)
+ }
+ if gotM != wantM {
+ t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM)
+ }
+ if wantMsg == nil {
+ if payloadLen != 0 {
+ t.Errorf("no payload expect but got %d bytes", payloadLen)
+ }
+ } else {
+ gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable)
+ gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+ if !reflect.DeepEqual(wantMsg, gotMsg) {
+ t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg)
+ }
+ }
+}
diff --git a/pkg/lisafs/testsuite/BUILD b/pkg/lisafs/testsuite/BUILD
new file mode 100644
index 000000000..b4a542b3a
--- /dev/null
+++ b/pkg/lisafs/testsuite/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testsuite",
+ testonly = True,
+ srcs = ["testsuite.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/lisafs",
+ "//pkg/unet",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go
new file mode 100644
index 000000000..476ff76a5
--- /dev/null
+++ b/pkg/lisafs/testsuite/testsuite.go
@@ -0,0 +1,637 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testsuite provides an integration testing suite for lisafs.
+// These tests are intended for servers serving the local filesystem.
+package testsuite
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/syndtr/gocapability/capability"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Tester is the client code using this test suite. This interface abstracts
+// away all the caller-specific details.
+type Tester interface {
+ // NewServer returns a new instance of the tester server.
+ NewServer(t *testing.T) *lisafs.Server
+
+ // LinkSupported returns true if the backing server supports LinkAt.
+ LinkSupported() bool
+
+ // SetUserGroupIDSupported returns true if the backing server supports
+ // changing UID/GID for files.
+ SetUserGroupIDSupported() bool
+}
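+
+// A hypothetical Tester implementation might look like the sketch below,
+// where newLocalServer is assumed to construct a lisafs.Server backed by the
+// local filesystem:
+//
+//	type localTester struct{}
+//
+//	func (localTester) NewServer(t *testing.T) *lisafs.Server { return newLocalServer(t) }
+//	func (localTester) LinkSupported() bool                   { return true }
+//	func (localTester) SetUserGroupIDSupported() bool         { return true }
+//
+// Such a tester would be driven via RunAllLocalFSTests(t, localTester{}).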
+
+// RunAllLocalFSTests runs all local FS tests as subtests.
+func RunAllLocalFSTests(t *testing.T, tester Tester) {
+ for name, testFn := range localFSTests {
+ t.Run(name, func(t *testing.T) {
+ runServerClient(t, tester, testFn)
+ })
+ }
+}
+
+type testFunc func(context.Context, *testing.T, Tester, lisafs.ClientFD)
+
+var localFSTests = map[string]testFunc{
+ "Stat": testStat,
+ "RegularFileIO": testRegularFileIO,
+ "RegularFileOpen": testRegularFileOpen,
+ "SetStat": testSetStat,
+ "Allocate": testAllocate,
+ "StatFS": testStatFS,
+ "Unlink": testUnlink,
+ "Symlink": testSymlink,
+ "HardLink": testHardLink,
+ "Walk": testWalk,
+ "Rename": testRename,
+ "Mknod": testMknod,
+ "Getdents": testGetdents,
+}
+
+func runServerClient(t *testing.T, tester Tester, testFn testFunc) {
+ mountPath, err := ioutil.TempDir(os.Getenv("TEST_TMPDIR"), "")
+ if err != nil {
+ t.Fatalf("creation of temporary mountpoint failed: %v", err)
+ }
+ defer os.RemoveAll(mountPath)
+
+ // fsgofer should run with a umask of 0, because we want to preserve file
+ // modes exactly for testing purposes.
+ unix.Umask(0)
+
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+
+ server := tester.NewServer(t)
+ conn, err := server.CreateConnection(serverSocket, false /* readonly */)
+ if err != nil {
+ t.Fatalf("starting connection failed: %v", err)
+ return
+ }
+ server.StartConnection(conn)
+
+ c, root, err := lisafs.NewClient(clientSocket, mountPath)
+ if err != nil {
+ t.Fatalf("client creation failed: %v", err)
+ }
+
+ if !root.ControlFD.Ok() {
+ t.Fatalf("root control FD is not valid")
+ }
+ rootFile := c.NewFD(root.ControlFD)
+
+ ctx := context.Background()
+ testFn(ctx, t, tester, rootFile)
+ closeFD(ctx, t, rootFile)
+
+ c.Close() // This should trigger client and server shutdown.
+ server.Wait()
+}
+
+func closeFD(ctx context.Context, t testing.TB, fdLisa lisafs.ClientFD) {
+ if err := fdLisa.Close(ctx); err != nil {
+ t.Errorf("failed to close FD: %v", err)
+ }
+}
+
+func statTo(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, stat *linux.Statx) {
+ if err := fdLisa.StatTo(ctx, stat); err != nil {
+ t.Fatalf("stat failed: %v", err)
+ }
+}
+
+func openCreateFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx, lisafs.ClientFD, int) {
+ child, childFD, childHostFD, err := fdLisa.OpenCreateAt(ctx, name, unix.O_RDWR, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("OpenCreateAt failed: %v", err)
+ }
+ if childHostFD == -1 {
+ t.Error("no host FD donated")
+ }
+ client := fdLisa.Client()
+ return client.NewFD(child.ControlFD), child.Stat, client.NewFD(childFD), childHostFD
+}
+
+func openFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, flags uint32, isReg bool) (lisafs.ClientFD, int) {
+ newFD, hostFD, err := fdLisa.OpenAt(ctx, flags)
+ if err != nil {
+ t.Fatalf("OpenAt failed: %v", err)
+ }
+ if hostFD == -1 && isReg {
+ t.Error("no host FD donated")
+ }
+ return fdLisa.Client().NewFD(newFD), hostFD
+}
+
+func unlinkFile(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, isDir bool) {
+ var flags uint32
+ if isDir {
+ flags = unix.AT_REMOVEDIR
+ }
+ if err := dir.UnlinkAt(ctx, name, flags); err != nil {
+ t.Errorf("unlinking file %s failed: %v", name, err)
+ }
+}
+
+func symlink(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name, target string) (lisafs.ClientFD, linux.Statx) {
+ linkIno, err := dir.SymlinkAt(ctx, name, target, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("symlink failed: %v", err)
+ }
+ return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat
+}
+
+func link(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, target lisafs.ClientFD) (lisafs.ClientFD, linux.Statx) {
+ linkIno, err := dir.LinkAt(ctx, target.ID(), name)
+ if err != nil {
+ t.Fatalf("link failed: %v", err)
+ }
+ return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat
+}
+
+func mkdir(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) {
+ childIno, err := dir.MkdirAt(ctx, name, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("mkdir failed: %v", err)
+ }
+ return dir.Client().NewFD(childIno.ControlFD), childIno.Stat
+}
+
+func mknod(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) {
+ nodeIno, err := dir.MknodAt(ctx, name, unix.S_IFREG|0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()), 0, 0)
+ if err != nil {
+ t.Fatalf("mknod failed: %v", err)
+ }
+ return dir.Client().NewFD(nodeIno.ControlFD), nodeIno.Stat
+}
+
+func walk(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []lisafs.Inode {
+ _, inodes, err := dir.WalkMultiple(ctx, names)
+ if err != nil {
+ t.Fatalf("walk failed while trying to walk components %+v: %v", names, err)
+ }
+ return inodes
+}
+
+func walkStat(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []linux.Statx {
+ stats, err := dir.WalkStat(ctx, names)
+ if err != nil {
+ t.Fatalf("walk failed while trying to walk components %+v: %v", names, err)
+ }
+ return stats
+}
+
+func writeFD(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, buf []byte) error {
+ count, err := fdLisa.Write(ctx, buf, off)
+ if err != nil {
+ return err
+ }
+ if int(count) != len(buf) {
+ t.Errorf("partial write: buf size = %d, written = %d", len(buf), count)
+ }
+ return nil
+}
+
+func readFDAndCmp(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, want []byte) {
+ buf := make([]byte, len(want))
+ n, err := fdLisa.Read(ctx, buf, off)
+ if err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if int(n) != len(want) {
+ t.Errorf("partial read: buf size = %d, read = %d", len(want), n)
+ return
+ }
+ if bytes.Compare(buf, want) != 0 {
+ t.Errorf("bytes read differ from what was expected: want = %v, got = %v", want, buf)
+ }
+}
+
+func allocateAndVerify(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, length uint64) {
+ if err := fdLisa.Allocate(ctx, 0, off, length); err != nil {
+ t.Fatalf("fallocate failed: %v", err)
+ }
+
+ var stat linux.Statx
+ statTo(ctx, t, fdLisa, &stat)
+ if want := off + length; stat.Size != want {
+ t.Errorf("incorrect file size after allocate: expected %d, got %d", off+length, stat.Size)
+ }
+}
+
+func cmpStatx(t *testing.T, want, got linux.Statx) {
+ if got.Mask&unix.STATX_MODE != 0 && want.Mask&unix.STATX_MODE != 0 {
+ if got.Mode != want.Mode {
+ t.Errorf("mode differs: want %d, got %d", want.Mode, got.Mode)
+ }
+ }
+ if got.Mask&unix.STATX_INO != 0 && want.Mask&unix.STATX_INO != 0 {
+ if got.Ino != want.Ino {
+ t.Errorf("inode number differs: want %d, got %d", want.Ino, got.Ino)
+ }
+ }
+ if got.Mask&unix.STATX_NLINK != 0 && want.Mask&unix.STATX_NLINK != 0 {
+ if got.Nlink != want.Nlink {
+ t.Errorf("nlink differs: want %d, got %d", want.Nlink, got.Nlink)
+ }
+ }
+ if got.Mask&unix.STATX_UID != 0 && want.Mask&unix.STATX_UID != 0 {
+ if got.UID != want.UID {
+ t.Errorf("UID differs: want %d, got %d", want.UID, got.UID)
+ }
+ }
+ if got.Mask&unix.STATX_GID != 0 && want.Mask&unix.STATX_GID != 0 {
+ if got.GID != want.GID {
+ t.Errorf("GID differs: want %d, got %d", want.GID, got.GID)
+ }
+ }
+ if got.Mask&unix.STATX_SIZE != 0 && want.Mask&unix.STATX_SIZE != 0 {
+ if got.Size != want.Size {
+ t.Errorf("size differs: want %d, got %d", want.Size, got.Size)
+ }
+ }
+ if got.Mask&unix.STATX_BLOCKS != 0 && want.Mask&unix.STATX_BLOCKS != 0 {
+ if got.Blocks != want.Blocks {
+ t.Errorf("blocks differs: want %d, got %d", want.Blocks, got.Blocks)
+ }
+ }
+ if got.Mask&unix.STATX_ATIME != 0 && want.Mask&unix.STATX_ATIME != 0 {
+ if got.Atime != want.Atime {
+ t.Errorf("atime differs: want %d, got %d", want.Atime, got.Atime)
+ }
+ }
+ if got.Mask&unix.STATX_MTIME != 0 && want.Mask&unix.STATX_MTIME != 0 {
+ if got.Mtime != want.Mtime {
+ t.Errorf("mtime differs: want %d, got %d", want.Mtime, got.Mtime)
+ }
+ }
+ if got.Mask&unix.STATX_CTIME != 0 && want.Mask&unix.STATX_CTIME != 0 {
+ if got.Ctime != want.Ctime {
+ t.Errorf("ctime differs: want %d, got %d", want.Ctime, got.Ctime)
+ }
+ }
+}
+
+func hasCapability(c capability.Cap) bool {
+ caps, err := capability.NewPid2(os.Getpid())
+ if err != nil {
+ return false
+ }
+ if err := caps.Load(); err != nil {
+ return false
+ }
+ return caps.Get(capability.EFFECTIVE, c)
+}
+
+func testStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ var rootStat linux.Statx
+ if err := root.StatTo(ctx, &rootStat); err != nil {
+ t.Errorf("stat on the root dir failed: %v", err)
+ }
+
+ if ftype := rootStat.Mode & unix.S_IFMT; ftype != unix.S_IFDIR {
+ t.Errorf("root inode is not a directory, file type = %d", ftype)
+ }
+}
+
+func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ // Test Read/Write RPCs.
+ data := make([]byte, 100)
+ rand.Read(data)
+ if err := writeFD(ctx, t, fd, 0, data); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ readFDAndCmp(ctx, t, fd, 0, data)
+ readFDAndCmp(ctx, t, fd, 50, data[50:])
+
+ // Make sure the host FD is configured properly.
+ hostReadData := make([]byte, len(data))
+ if n, err := unix.Pread(hostFD, hostReadData, 0); err != nil {
+ t.Errorf("host read failed: %v", err)
+ } else if n != len(hostReadData) {
+ t.Errorf("partial read: buf size = %d, read = %d", len(hostReadData), n)
+ } else if bytes.Compare(hostReadData, data) != 0 {
+ t.Errorf("bytes read differ from what was expected: want = %v, got = %v", data, hostReadData)
+ }
+
+ // Test syncing the writable FD.
+ if err := fd.Sync(ctx); err != nil {
+ t.Errorf("syncing the FD failed: %v", err)
+ }
+}
+
+func testRegularFileOpen(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ // Open a readonly FD and try writing to it to get an EBADF.
+ roFile, roHostFD := openFile(ctx, t, controlFile, unix.O_RDONLY, true /* isReg */)
+ defer closeFD(ctx, t, roFile)
+ defer unix.Close(roHostFD)
+ if err := writeFD(ctx, t, roFile, 0, []byte{1, 2, 3}); err != unix.EBADF {
+ t.Errorf("writing to read only FD should generate EBADF, but got %v", err)
+ }
+}
+
+func testSetStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ now := time.Now()
+ wantStat := linux.Statx{
+ Mask: unix.STATX_MODE | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_SIZE,
+ Mode: 0760,
+ UID: uint32(unix.Getuid()),
+ GID: uint32(unix.Getgid()),
+ Size: 50,
+ Atime: linux.NsecToStatxTimestamp(now.UnixNano()),
+ Mtime: linux.NsecToStatxTimestamp(now.UnixNano()),
+ }
+ if tester.SetUserGroupIDSupported() {
+ wantStat.Mask |= unix.STATX_UID | unix.STATX_GID
+ }
+ failureMask, failureErr, err := controlFile.SetStat(ctx, &wantStat)
+ if err != nil {
+ t.Fatalf("setstat failed: %v", err)
+ }
+ if failureMask != 0 {
+ t.Fatalf("some setstat operations failed: failureMask = %#b, failureErr = %v", failureMask, failureErr)
+ }
+
+ // Verify that attributes were updated.
+ var gotStat linux.Statx
+ statTo(ctx, t, controlFile, &gotStat)
+ if gotStat.Mode&07777 != wantStat.Mode ||
+ gotStat.Size != wantStat.Size ||
+ gotStat.Atime.ToNsec() != wantStat.Atime.ToNsec() ||
+ gotStat.Mtime.ToNsec() != wantStat.Mtime.ToNsec() ||
+ (tester.SetUserGroupIDSupported() && (uint32(gotStat.UID) != wantStat.UID || uint32(gotStat.GID) != wantStat.GID)) {
+ t.Errorf("setStat did not update file correctly: setStat = %+v, stat = %+v", wantStat, gotStat)
+ }
+}
+
+func testAllocate(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ allocateAndVerify(ctx, t, fd, 0, 40)
+ allocateAndVerify(ctx, t, fd, 20, 100)
+}
+
+func testStatFS(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ var statFS lisafs.StatFS
+ if err := root.StatFSTo(ctx, &statFS); err != nil {
+ t.Errorf("statfs failed: %v", err)
+ }
+}
+
+func testUnlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ unlinkFile(ctx, t, root, name, false /* isDir */)
+ if inodes := walk(ctx, t, root, []string{name}); len(inodes) > 0 {
+ t.Errorf("deleted file should not be generating inodes on walk: %+v", inodes)
+ }
+}
+
+func testSymlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ target := "/tmp/some/path"
+ name := "symlinkFile"
+ link, linkStat := symlink(ctx, t, root, name, target)
+ defer closeFD(ctx, t, link)
+
+ if linkStat.Mode&unix.S_IFMT != unix.S_IFLNK {
+ t.Errorf("stat return from symlink RPC indicates that the inode is not a symlink: mode = %d", linkStat.Mode)
+ }
+
+ if gotTarget, err := link.ReadLinkAt(ctx); err != nil {
+ t.Fatalf("readlink failed: %v", err)
+ } else if gotTarget != target {
+ t.Errorf("readlink return incorrect target: expected %q, got %q", target, gotTarget)
+ }
+}
+
+func testHardLink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ if !tester.LinkSupported() {
+ t.Skipf("server does not support LinkAt RPC")
+ }
+ if !hasCapability(capability.CAP_DAC_READ_SEARCH) {
+ t.Skipf("TestHardLink requires CAP_DAC_READ_SEARCH, running as %d", unix.Getuid())
+ }
+ name := "tempFile"
+ controlFile, fileIno, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ link, linkStat := link(ctx, t, root, name, controlFile)
+ defer closeFD(ctx, t, link)
+
+ if linkStat.Ino != fileIno.Ino {
+ t.Errorf("hard linked files have different inode numbers: %d %d", linkStat.Ino, fileIno.Ino)
+ }
+ if linkStat.DevMinor != fileIno.DevMinor {
+ t.Errorf("hard linked files have different minor device numbers: %d %d", linkStat.DevMinor, fileIno.DevMinor)
+ }
+ if linkStat.DevMajor != fileIno.DevMajor {
+ t.Errorf("hard linked files have different major device numbers: %d %d", linkStat.DevMajor, fileIno.DevMajor)
+ }
+}
+
+func testWalk(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ // Create 10 nested directories.
+ n := 10
+ curDir := root
+
+ dirNames := make([]string, 0, n)
+ for i := 0; i < n; i++ {
+ name := fmt.Sprintf("tmpdir-%d", i)
+ childDir, _ := mkdir(ctx, t, curDir, name)
+ defer closeFD(ctx, t, childDir)
+ defer unlinkFile(ctx, t, curDir, name, true /* isDir */)
+
+ curDir = childDir
+ dirNames = append(dirNames, name)
+ }
+
+ // Walk all these directories. Add some junk components at the end, which
+ // should not be walked.
+ dirNames = append(dirNames, []string{"a", "b", "c"}...)
+ inodes := walk(ctx, t, root, dirNames)
+ if len(inodes) != n {
+ t.Errorf("walk returned the incorrect number of inodes: wanted %d, got %d", n, len(inodes))
+ }
+
+ // Close all control FDs and collect stat results for all dirs including
+ // the root directory.
+ dirStats := make([]linux.Statx, 0, n+1)
+ var stat linux.Statx
+ statTo(ctx, t, root, &stat)
+ dirStats = append(dirStats, stat)
+ for _, inode := range inodes {
+ dirStats = append(dirStats, inode.Stat)
+ closeFD(ctx, t, root.Client().NewFD(inode.ControlFD))
+ }
+
+ // Test WalkStat, which additionally returns a Statx for the root because the
+ // first path component is "".
+ dirNames = append([]string{""}, dirNames...)
+ gotStats := walkStat(ctx, t, root, dirNames)
+ if len(gotStats) != len(dirStats) {
+ t.Errorf("walkStat returned the incorrect number of statx: wanted %d, got %d", len(dirStats), len(gotStats))
+ } else {
+ for i := range gotStats {
+ cmpStatx(t, dirStats[i], gotStats[i])
+ }
+ }
+}
+
+func testRename(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ tempFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, tempFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ tempDir, _ := mkdir(ctx, t, root, "tempDir")
+ defer closeFD(ctx, t, tempDir)
+
+ // Move tempFile into tempDir.
+ if err := tempFile.RenameTo(ctx, tempDir.ID(), "movedFile"); err != nil {
+ t.Fatalf("rename failed: %v", err)
+ }
+
+ inodes := walkStat(ctx, t, root, []string{"tempDir", "movedFile"})
+ if len(inodes) != 2 {
+ t.Errorf("expected 2 files on walk but only found %d", len(inodes))
+ }
+}
+
+func testMknod(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "namedPipe"
+ pipeFile, pipeStat := mknod(ctx, t, root, name)
+ defer closeFD(ctx, t, pipeFile)
+
+ var stat linux.Statx
+ statTo(ctx, t, pipeFile, &stat)
+
+ if stat.Mode != pipeStat.Mode {
+ t.Errorf("mknod mode is incorrect: want %d, got %d", pipeStat.Mode, stat.Mode)
+ }
+ if stat.UID != pipeStat.UID {
+ t.Errorf("mknod UID is incorrect: want %d, got %d", pipeStat.UID, stat.UID)
+ }
+ if stat.GID != pipeStat.GID {
+ t.Errorf("mknod GID is incorrect: want %d, got %d", pipeStat.GID, stat.GID)
+ }
+}
+
+func testGetdents(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ tempDir, _ := mkdir(ctx, t, root, "tempDir")
+ defer closeFD(ctx, t, tempDir)
+ defer unlinkFile(ctx, t, root, "tempDir", true /* isDir */)
+
+ // Create 10 files in tempDir.
+ n := 10
+ fileStats := make(map[string]linux.Statx)
+ for i := 0; i < n; i++ {
+ name := fmt.Sprintf("file-%d", i)
+ newFile, fileStat := mknod(ctx, t, tempDir, name)
+ defer closeFD(ctx, t, newFile)
+ defer unlinkFile(ctx, t, tempDir, name, false /* isDir */)
+
+ fileStats[name] = fileStat
+ }
+
+ // Use opened directory FD for getdents.
+ openDirFile, _ := openFile(ctx, t, tempDir, unix.O_RDONLY, false /* isReg */)
+ defer closeFD(ctx, t, openDirFile)
+
+ dirents := make([]lisafs.Dirent64, 0, n)
+ for i := 0; i < n+2; i++ {
+ gotDirents, err := openDirFile.Getdents64(ctx, 40)
+ if err != nil {
+ t.Fatalf("getdents failed: %v", err)
+ }
+ if len(gotDirents) == 0 {
+ break
+ }
+ for _, dirent := range gotDirents {
+ if dirent.Name != "." && dirent.Name != ".." {
+ dirents = append(dirents, dirent)
+ }
+ }
+ }
+
+ if len(dirents) != n {
+ t.Errorf("got incorrect number of dirents: wanted %d, got %d", n, len(dirents))
+ }
+ for _, dirent := range dirents {
+ stat, ok := fileStats[string(dirent.Name)]
+ if !ok {
+ t.Errorf("received a dirent that was not created: %+v", dirent)
+ continue
+ }
+
+ if dirent.Type != unix.DT_REG {
+ t.Errorf("dirent type of %s is incorrect: %d", dirent.Name, dirent.Type)
+ }
+ if uint64(dirent.Ino) != stat.Ino {
+ t.Errorf("dirent ino of %s is incorrect: want %d, got %d", dirent.Name, stat.Ino, dirent.Ino)
+ }
+ if uint32(dirent.DevMinor) != stat.DevMinor {
+ t.Errorf("dirent dev minor of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMinor, dirent.DevMinor)
+ }
+ if uint32(dirent.DevMajor) != stat.DevMajor {
+ t.Errorf("dirent dev major of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMajor, dirent.DevMajor)
+ }
+ }
+}
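Note on the getdents test above: it drains the directory by calling Getdents64 in a loop until the server returns an empty batch. Below is a minimal standalone sketch of that draining pattern; it assumes only the lisafs.ClientFD.Getdents64 signature used in the test, and readAllDirents is a hypothetical helper, not part of the suite.

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/lisafs"
)

// readAllDirents pages through an opened directory FD until Getdents64
// returns an empty batch, skipping "." and ".." like testGetdents does.
func readAllDirents(ctx context.Context, dir lisafs.ClientFD) ([]lisafs.Dirent64, error) {
	var all []lisafs.Dirent64
	for {
		batch, err := dir.Getdents64(ctx, 40 /* count */)
		if err != nil {
			return nil, err
		}
		if len(batch) == 0 {
			return all, nil
		}
		for _, d := range batch {
			if name := string(d.Name); name == "." || name == ".." {
				continue
			}
			all = append(all, d)
		}
	}
}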
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index eb496f02f..d618da820 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -115,7 +115,7 @@ type Client struct {
// channels is the set of all initialized channels.
channels []*channel
- // availableChannels is a FIFO of inactive channels.
+ // availableChannels is a LIFO of inactive channels.
availableChannels []*channel
// -- below corresponds to sendRecvLegacy --
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 4244f2cf5..509dd0e1a 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -54,7 +54,10 @@ go_library(
"//pkg/fdnotifier",
"//pkg/fspath",
"//pkg/hostarch",
+ "//pkg/lisafs",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/p9",
"//pkg/refs",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 5c48a9fee..d99a6112c 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -222,47 +222,88 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
off := uint64(0)
const count = 64 * 1024 // for consistency with the vfs1 client
d.handleMu.RLock()
- if d.readFile.isNil() {
+ if !d.isReadFileOk() {
// This should not be possible because a readable handle should
// have been opened when the calling directoryFD was opened.
d.handleMu.RUnlock()
panic("gofer.dentry.getDirents called without a readable handle")
}
+ // shouldSeek0 indicates whether the server should SEEK to 0 before reading
+ // directory entries.
+ shouldSeek0 := true
for {
- p9ds, err := d.readFile.readdir(ctx, off, count)
- if err != nil {
- d.handleMu.RUnlock()
- return nil, err
- }
- if len(p9ds) == 0 {
- d.handleMu.RUnlock()
- break
- }
- for _, p9d := range p9ds {
- if p9d.Name == "." || p9d.Name == ".." {
- continue
+ if d.fs.opts.lisaEnabled {
+ countLisa := int32(count)
+ if shouldSeek0 {
+ // See lisafs.Getdents64Req.Count.
+ countLisa = -countLisa
+ shouldSeek0 = false
+ }
+ lisafsDs, err := d.readFDLisa.Getdents64(ctx, countLisa)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
+ }
+ if len(lisafsDs) == 0 {
+ d.handleMu.RUnlock()
+ break
+ }
+ for i := range lisafsDs {
+ name := string(lisafsDs[i].Name)
+ if name == "." || name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: name,
+ Ino: d.fs.inoFromKey(inoKey{
+ ino: uint64(lisafsDs[i].Ino),
+ devMinor: uint32(lisafsDs[i].DevMinor),
+ devMajor: uint32(lisafsDs[i].DevMajor),
+ }),
+ NextOff: int64(len(dirents) + 1),
+ Type: uint8(lisafsDs[i].Type),
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[name] = struct{}{}
+ }
}
- dirent := vfs.Dirent{
- Name: p9d.Name,
- Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
- NextOff: int64(len(dirents) + 1),
+ } else {
+ p9ds, err := d.readFile.readdir(ctx, off, count)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
}
- // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
- // DMSOCKET.
- switch p9d.Type {
- case p9.TypeSymlink:
- dirent.Type = linux.DT_LNK
- case p9.TypeDir:
- dirent.Type = linux.DT_DIR
- default:
- dirent.Type = linux.DT_REG
+ if len(p9ds) == 0 {
+ d.handleMu.RUnlock()
+ break
}
- dirents = append(dirents, dirent)
- if realChildren != nil {
- realChildren[p9d.Name] = struct{}{}
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[p9d.Name] = struct{}{}
+ }
}
+ off = p9ds[len(p9ds)-1].Offset
}
- off = p9ds[len(p9ds)-1].Offset
}
}
// Emit entries for synthetic children.
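Note on the lisafs branch above: it encodes "rewind before reading" in the sign of the count passed to Getdents64. Per the lisafs.Getdents64Req.Count comment referenced in the diff, a negative count asks the server to seek the directory FD back to offset 0 before reading, which the client does only for the first batch of a fresh enumeration. A small illustrative sketch of that convention (getdentsCount is hypothetical, not part of the change):

// getdentsCount returns the count to pass to Getdents64: negative on the
// first batch so the server seeks to offset 0 before reading, positive
// afterwards so reading continues from the current offset.
func getdentsCount(batchSize int32, firstBatch bool) int32 {
	if firstBatch {
		return -batchSize
	}
	return batchSize
}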
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 00228c469..f7b3446d3 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -21,10 +21,12 @@ import (
"sync"
"sync/atomic"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
@@ -53,9 +55,47 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// regardless.
var retErr error
+ if fs.opts.lisaEnabled {
+		// Try accumulating all FDIDs to fsync and then fsync them via one RPC, as
+		// opposed to making an RPC per FDID. Passing a non-nil accFsyncFDIDs to
+		// dentry.syncCachedFile() and specialFileFD.sync() causes them to skip the
+		// RPC and instead accumulate syncable FDIDs in the passed slice.
+ accFsyncFDIDs := make([]lisafs.FDID, 0, len(ds)+len(sffds))
+
+ // Sync syncable dentries.
+ for _, d := range ds {
+ if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+ }
+
+ // Sync special files, which may be writable but do not use dentry shared
+ // handles (so they won't be synced by the above).
+ for _, sffd := range sffds {
+ if err := sffd.sync(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+ }
+
+ if err := fs.clientLisa.SyncFDs(ctx, accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: fs.fsyncMultipleFDLisa failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+
+ return retErr
+ }
+
// Sync syncable dentries.
for _, d := range ds {
- if err := d.syncCachedFile(ctx, true /* forFilesystemSync */); err != nil {
+ if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil {
ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
if retErr == nil {
retErr = err
@@ -66,7 +106,7 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- if err := sffd.sync(ctx, true /* forFilesystemSync */); err != nil {
+ if err := sffd.sync(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil {
ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
if retErr == nil {
retErr = err
@@ -197,7 +237,13 @@ afterSymlink:
rp.Advance()
return d.parent, followedSymlink, nil
}
- child, err := fs.getChildLocked(ctx, d, name, ds)
+ var child *dentry
+ var err error
+ if fs.opts.lisaEnabled {
+ child, err = fs.getChildAndWalkPathLocked(ctx, d, rp, ds)
+ } else {
+ child, err = fs.getChildLocked(ctx, d, name, ds)
+ }
if err != nil {
return nil, false, err
}
@@ -219,6 +265,99 @@ afterSymlink:
return child, followedSymlink, nil
}
+// Preconditions:
+// * fs.opts.lisaEnabled.
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
+// * parent.isDir().
+// * parent and the dentry at name have been revalidated.
+func (fs *filesystem) getChildAndWalkPathLocked(ctx context.Context, parent *dentry, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ // Note that pit is a copy of the iterator that does not affect rp.
+ pit := rp.Pit()
+ first := pit.String()
+ if len(first) > maxFilenameLen {
+ return nil, linuxerr.ENAMETOOLONG
+ }
+ if child, ok := parent.children[first]; ok || parent.isSynthetic() {
+ if child == nil {
+ return nil, linuxerr.ENOENT
+ }
+ return child, nil
+ }
+
+ // Walk as much of the path as possible in 1 RPC.
+ names := []string{first}
+ for pit = pit.Next(); pit.Ok(); pit = pit.Next() {
+ name := pit.String()
+ if name == "." {
+ continue
+ }
+ if name == ".." {
+ break
+ }
+ names = append(names, name)
+ }
+ status, inodes, err := parent.controlFDLisa.WalkMultiple(ctx, names)
+ if err != nil {
+ return nil, err
+ }
+ if len(inodes) == 0 {
+ parent.cacheNegativeLookupLocked(first)
+ return nil, linuxerr.ENOENT
+ }
+
+ // Add the walked inodes into the dentry tree.
+ curParent := parent
+ curParentDirMuLock := func() {
+ if curParent != parent {
+ curParent.dirMu.Lock()
+ }
+ }
+ curParentDirMuUnlock := func() {
+ if curParent != parent {
+ curParent.dirMu.Unlock() // +checklocksforce: locked via curParentDirMuLock().
+ }
+ }
+ var ret *dentry
+ var dentryCreationErr error
+ for i := range inodes {
+ if dentryCreationErr != nil {
+ fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD)
+ continue
+ }
+
+ child, err := fs.newDentryLisa(ctx, &inodes[i])
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD)
+ dentryCreationErr = err
+ continue
+ }
+ curParentDirMuLock()
+ curParent.cacheNewChildLocked(child, names[i])
+ curParentDirMuUnlock()
+ // For now, child has 0 references, so our caller should call
+ // child.checkCachingLocked(). curParent gained a ref so we should also
+ // call curParent.checkCachingLocked() so it can be removed from the cache
+ // if needed. We only do that for the first iteration because all
+ // subsequent parents would have already been added to ds.
+ if i == 0 {
+ *ds = appendDentry(*ds, curParent)
+ }
+ *ds = appendDentry(*ds, child)
+ curParent = child
+ if i == 0 {
+ ret = child
+ }
+ }
+
+ if status == lisafs.WalkComponentDoesNotExist && curParent.isDir() {
+ curParentDirMuLock()
+ curParent.cacheNegativeLookupLocked(names[len(inodes)])
+ curParentDirMuUnlock()
+ }
+ return ret, dentryCreationErr
+}
+
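A brief sketch of how the (status, inodes) pair returned by WalkMultiple above is interpreted: inodes[i] corresponds to names[i], and when the walk stops early with WalkComponentDoesNotExist, names[len(inodes)] is the missing component that getChildAndWalkPathLocked records as a negative lookup. walkChain below is illustrative only and assumes the same package and imports as the surrounding code:

// walkChain issues one batched WalkMultiple RPC and reports which path
// component, if any, was found to be missing on the server.
func walkChain(ctx context.Context, start lisafs.ClientFD, names []string) ([]lisafs.Inode, string, error) {
	status, inodes, err := start.WalkMultiple(ctx, names)
	if err != nil {
		return nil, "", err
	}
	var missing string
	if status == lisafs.WalkComponentDoesNotExist && len(inodes) < len(names) {
		missing = names[len(inodes)]
	}
	return inodes, missing, nil
}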
// getChildLocked returns a dentry representing the child of parent with the
// given name. Returns ENOENT if the child doesn't exist.
//
@@ -227,7 +366,7 @@ afterSymlink:
// * parent.dirMu must be locked.
// * parent.isDir().
// * name is not "." or "..".
-// * dentry at name has been revalidated
+// * parent and the dentry at name have been revalidated.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if len(name) > maxFilenameLen {
return nil, linuxerr.ENAMETOOLONG
@@ -239,20 +378,35 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
return child, nil
}
- qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
- if err != nil {
- if linuxerr.Equals(linuxerr.ENOENT, err) {
- parent.cacheNegativeLookupLocked(name)
+ var child *dentry
+ if fs.opts.lisaEnabled {
+ childInode, err := parent.controlFDLisa.Walk(ctx, name)
+ if err != nil {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
+ parent.cacheNegativeLookupLocked(name)
+ }
+ return nil, err
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentryLisa(ctx, childInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, childInode.ControlFD)
+ return nil, err
+ }
+ } else {
+ qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
+ parent.cacheNegativeLookupLocked(name)
+ }
+ return nil, err
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ if err != nil {
+ file.close(ctx)
+ return nil, err
}
- return nil, err
- }
-
- // Create a new dentry representing the file.
- child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
- if err != nil {
- file.close(ctx)
- delete(parent.children, name)
- return nil, err
}
parent.cacheNewChildLocked(child, name)
appendNewChildDentry(ds, parent, child)
@@ -328,7 +482,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
// Preconditions:
// * !rp.Done().
// * For the final path component in rp, !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error), createInSyntheticDir func(parent *dentry, name string) error, updateChild func(child *dentry)) error {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
@@ -415,9 +569,26 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
// No cached dentry exists; however, in InteropModeShared there might still be
// an existing file at name. Just attempt the file creation RPC anyways. If a
// file does exist, the RPC will fail with EEXIST like we would have.
- if err := createInRemoteDir(parent, name, &ds); err != nil {
+ lisaInode, err := createInRemoteDir(parent, name, &ds)
+ if err != nil {
return err
}
+	// lisafs may aggressively cache newly created inodes. This has helped
+	// reduce Walk RPCs in practice.
+ if lisaInode != nil {
+ child, err := fs.newDentryLisa(ctx, lisaInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, lisaInode.ControlFD)
+ return err
+ }
+ parent.cacheNewChildLocked(child, name)
+ appendNewChildDentry(&ds, parent, child)
+
+ // lisafs may update dentry properties upon successful creation.
+ if updateChild != nil {
+ updateChild(child)
+ }
+ }
if fs.opts.interop != InteropModeShared {
if child, ok := parent.children[name]; ok && child == nil {
// Delete the now-stale negative dentry.
@@ -565,7 +736,11 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return linuxerr.ENOENT
}
} else if child == nil || !child.isSynthetic() {
- err = parent.file.unlinkAt(ctx, name, flags)
+ if fs.opts.lisaEnabled {
+ err = parent.controlFDLisa.UnlinkAt(ctx, name, flags)
+ } else {
+ err = parent.file.unlinkAt(ctx, name, flags)
+ }
if err != nil {
if child != nil {
vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
@@ -658,40 +833,43 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error {
+ err := fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, ds **[]*dentry) (*lisafs.Inode, error) {
if rp.Mount() != vd.Mount() {
- return linuxerr.EXDEV
+ return nil, linuxerr.EXDEV
}
d := vd.Dentry().Impl().(*dentry)
if d.isDir() {
- return linuxerr.EPERM
+ return nil, linuxerr.EPERM
}
gid := auth.KGID(atomic.LoadUint32(&d.gid))
uid := auth.KUID(atomic.LoadUint32(&d.uid))
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil {
- return err
+ return nil, err
}
if d.nlink == 0 {
- return linuxerr.ENOENT
+ return nil, linuxerr.ENOENT
}
if d.nlink == math.MaxUint32 {
- return linuxerr.EMLINK
+ return nil, linuxerr.EMLINK
}
- if err := parent.file.link(ctx, d.file, childName); err != nil {
- return err
+ if fs.opts.lisaEnabled {
+ return parent.controlFDLisa.LinkAt(ctx, d.controlFDLisa.ID(), childName)
}
+ return nil, parent.file.link(ctx, d.file, childName)
+ }, nil, nil)
+ if err == nil {
// Success!
- atomic.AddUint32(&d.nlink, 1)
- return nil
- }, nil)
+ vd.Dentry().Impl().(*dentry).incLinks()
+ }
+ return err
}
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
// If the parent is a setgid directory, use the parent's GID
// rather than the caller's and enable setgid.
kgid := creds.EffectiveKGID
@@ -700,9 +878,18 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid = auth.KGID(atomic.LoadUint32(&parent.gid))
mode |= linux.S_ISGID
}
- if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil {
+ var (
+ childDirInode *lisafs.Inode
+ err error
+ )
+ if fs.opts.lisaEnabled {
+ childDirInode, err = parent.controlFDLisa.MkdirAt(ctx, name, mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid))
+ } else {
+ _, err = parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
+ }
+ if err != nil {
if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) {
- return err
+ return nil, err
}
ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
parent.createSyntheticChildLocked(&createSyntheticOpts{
@@ -716,7 +903,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if fs.opts.interop != InteropModeShared {
parent.incLinks()
}
- return nil
+ return childDirInode, nil
}, func(parent *dentry, name string) error {
if !opts.ForSyntheticMountpoint {
// Can't create non-synthetic files in synthetic directories.
@@ -730,16 +917,26 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
})
parent.incLinks()
return nil
- })
+ }, nil)
}
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
creds := rp.Credentials()
- _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- if !linuxerr.Equals(linuxerr.EPERM, err) {
- return err
+ var (
+ childInode *lisafs.Inode
+ err error
+ )
+ if fs.opts.lisaEnabled {
+ childInode, err = parent.controlFDLisa.MknodAt(ctx, name, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID), opts.DevMinor, opts.DevMajor)
+ } else {
+ _, err = parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ }
+ if err == nil {
+ return childInode, nil
+ } else if !linuxerr.Equals(linuxerr.EPERM, err) {
+ return nil, err
}
// EPERM means that gofer does not allow creating a socket or pipe. Fallback
@@ -750,10 +947,10 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
switch {
case err == nil:
// Step succeeded, another file exists.
- return linuxerr.EEXIST
+ return nil, linuxerr.EEXIST
case !linuxerr.Equals(linuxerr.ENOENT, err):
// Unexpected error.
- return err
+ return nil, err
}
switch opts.Mode.FileType() {
@@ -766,7 +963,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
endpoint: opts.Endpoint,
})
*ds = appendDentry(*ds, parent)
- return nil
+ return nil, nil
case linux.S_IFIFO:
parent.createSyntheticChildLocked(&createSyntheticOpts{
name: name,
@@ -776,11 +973,11 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
*ds = appendDentry(*ds, parent)
- return nil
+ return nil, nil
}
// Retain error from gofer if synthetic file cannot be created internally.
- return linuxerr.EPERM
- }, nil)
+ return nil, linuxerr.EPERM
+ }, nil, nil)
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
@@ -986,6 +1183,23 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
if opts.Flags&linux.O_DIRECT != 0 {
return nil, linuxerr.EINVAL
}
+ if d.fs.opts.lisaEnabled {
+		// Note that the special value linux.SockType = 0 is interpreted by
+		// lisafs as "any socket type", analogous to p9.AnonymousSocket.
+ sockFD, err := d.controlFDLisa.Connect(ctx, 0 /* sockType */)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), sockFD, &host.NewFDOptions{
+ HaveFlags: true,
+ Flags: opts.Flags,
+ })
+ if err != nil {
+ unix.Close(sockFD)
+ return nil, err
+ }
+ return fd, nil
+ }
fdObj, err := d.file.connect(ctx, p9.AnonymousSocket)
if err != nil {
return nil, err
@@ -998,6 +1212,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
fdObj.Close()
return nil, err
}
+ // Ownership has been transferred to fd.
fdObj.Release()
return fd, nil
}
@@ -1017,7 +1232,13 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.
// since closed its end.
isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0
retry:
- h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ var h handle
+ var err error
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ } else {
+ h, err = openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ }
if err != nil {
if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) {
// An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
@@ -1061,18 +1282,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
defer mnt.EndWrite()
- // 9P2000.L's lcreate takes a fid representing the parent directory, and
- // converts it into an open fid representing the created file, so we need
- // to duplicate the directory fid first.
- _, dirfile, err := d.file.walk(ctx, nil)
- if err != nil {
- return nil, err
- }
creds := rp.Credentials()
name := rp.Component()
- // We only want the access mode for creating the file.
- createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
-
// If the parent is a setgid directory, use the parent's GID rather
// than the caller's.
kgid := creds.EffectiveKGID
@@ -1080,51 +1291,87 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
kgid = auth.KGID(atomic.LoadUint32(&d.gid))
}
- fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
- if err != nil {
- dirfile.close(ctx)
- return nil, err
- }
- // Then we need to walk to the file we just created to get a non-open fid
- // representing it, and to get its metadata. This must use d.file since, as
- // explained above, dirfile was invalidated by dirfile.Create().
- _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
- if err != nil {
- openFile.close(ctx)
- if fdobj != nil {
- fdobj.Close()
+ var child *dentry
+ var openP9File p9file
+ openLisaFD := lisafs.InvalidFDID
+ openHostFD := int32(-1)
+ if d.fs.opts.lisaEnabled {
+ ino, openFD, hostFD, err := d.controlFDLisa.OpenCreateAt(ctx, name, opts.Flags&linux.O_ACCMODE, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid))
+ if err != nil {
+ return nil, err
+ }
+ openHostFD = int32(hostFD)
+ openLisaFD = openFD
+
+ child, err = d.fs.newDentryLisa(ctx, &ino)
+ if err != nil {
+ d.fs.clientLisa.CloseFDBatched(ctx, ino.ControlFD)
+ d.fs.clientLisa.CloseFDBatched(ctx, openFD)
+ if hostFD >= 0 {
+ unix.Close(hostFD)
+ }
+ return nil, err
+ }
+ } else {
+ // 9P2000.L's lcreate takes a fid representing the parent directory, and
+ // converts it into an open fid representing the created file, so we need
+ // to duplicate the directory fid first.
+ _, dirfile, err := d.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ // We only want the access mode for creating the file.
+ createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
+
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
+ if err != nil {
+ dirfile.close(ctx)
+ return nil, err
+ }
+ // Then we need to walk to the file we just created to get a non-open fid
+ // representing it, and to get its metadata. This must use d.file since, as
+ // explained above, dirfile was invalidated by dirfile.Create().
+ _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+
+ // Construct the new dentry.
+ child, err = d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
+ if err != nil {
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
}
- return nil, err
- }
- // Construct the new dentry.
- child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
- if err != nil {
- nonOpenFile.close(ctx)
- openFile.close(ctx)
if fdobj != nil {
- fdobj.Close()
+ openHostFD = int32(fdobj.Release())
}
- return nil, err
+ openP9File = openFile
}
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
- openFD := int32(-1)
- if fdobj != nil {
- openFD = int32(fdobj.Release())
- }
child.handleMu.Lock()
if vfs.MayReadFileWithOpenFlags(opts.Flags) {
- child.readFile = openFile
- if fdobj != nil {
- child.readFD = openFD
- child.mmapFD = openFD
+ child.readFile = openP9File
+ child.readFDLisa = d.fs.clientLisa.NewFD(openLisaFD)
+ if openHostFD != -1 {
+ child.readFD = openHostFD
+ child.mmapFD = openHostFD
}
}
if vfs.MayWriteFileWithOpenFlags(opts.Flags) {
- child.writeFile = openFile
- child.writeFD = openFD
+ child.writeFile = openP9File
+ child.writeFDLisa = d.fs.clientLisa.NewFD(openLisaFD)
+ child.writeFD = openHostFD
}
child.handleMu.Unlock()
}
@@ -1146,11 +1393,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
childVFSFD = &fd.vfsfd
} else {
h := handle{
- file: openFile,
- fd: -1,
- }
- if fdobj != nil {
- h.fd = int32(fdobj.Release())
+ file: openP9File,
+ fdLisa: d.fs.clientLisa.NewFD(openLisaFD),
+ fd: openHostFD,
}
fd, err := newSpecialFileFD(h, mnt, child, opts.Flags)
if err != nil {
@@ -1304,7 +1549,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Update the remote filesystem.
if !renamed.isSynthetic() {
- if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ if fs.opts.lisaEnabled {
+ err = renamed.controlFDLisa.RenameTo(ctx, newParent.controlFDLisa.ID(), newName)
+ } else {
+ err = renamed.file.rename(ctx, newParent.file, newName)
+ }
+ if err != nil {
vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
return err
}
@@ -1315,7 +1565,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if replaced.isDir() {
flags = linux.AT_REMOVEDIR
}
- if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil {
+ if fs.opts.lisaEnabled {
+ err = newParent.controlFDLisa.UnlinkAt(ctx, newName, flags)
+ } else {
+ err = newParent.file.unlinkAt(ctx, newName, flags)
+ }
+ if err != nil {
vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
return err
}
@@ -1431,6 +1686,28 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
for d.isSynthetic() {
d = d.parent
}
+ if fs.opts.lisaEnabled {
+ var statFS lisafs.StatFS
+ if err := d.controlFDLisa.StatFSTo(ctx, &statFS); err != nil {
+ return linux.Statfs{}, err
+ }
+ if statFS.NameLength > maxFilenameLen {
+ statFS.NameLength = maxFilenameLen
+ }
+ return linux.Statfs{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Testing is important, so instead of defining
+ // something completely random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ BlockSize: statFS.BlockSize,
+ Blocks: statFS.Blocks,
+ BlocksFree: statFS.BlocksFree,
+ BlocksAvailable: statFS.BlocksAvailable,
+ Files: statFS.Files,
+ FilesFree: statFS.FilesFree,
+ NameLength: statFS.NameLength,
+ }, nil
+ }
fsstat, err := d.file.statFS(ctx)
if err != nil {
return linux.Statfs{}, err
@@ -1456,11 +1733,21 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
creds := rp.Credentials()
+ if fs.opts.lisaEnabled {
+ return parent.controlFDLisa.SymlinkAt(ctx, name, target, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID))
+ }
_, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- return err
- }, nil)
+ return nil, err
+ }, nil, func(child *dentry) {
+ if fs.opts.interop != InteropModeShared {
+ // lisafs caches the symlink target on creation. In practice, this
+ // helps avoid a lot of ReadLink RPCs.
+ child.haveTarget = true
+ child.target = target
+ }
+ })
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
@@ -1505,7 +1792,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
if err != nil {
return nil, err
}
- return d.listXattr(ctx, rp.Credentials(), size)
+ return d.listXattr(ctx, size)
}
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
@@ -1612,6 +1899,9 @@ func (fs *filesystem) MountOptions() string {
if fs.opts.overlayfsStaleRead {
optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil})
}
+ if fs.opts.lisaEnabled {
+ optsKV = append(optsKV, mopt{moptLisafs, nil})
+ }
opts := make([]string, 0, len(optsKV))
for _, opt := range optsKV {
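Note on the Sync change near the top of this file: it accumulates FDIDs and flushes them with a single RPC instead of issuing one RPC per FD. Below is a minimal sketch of that accumulate-or-RPC contract; doFsync stands in for the per-FD RPC, so this illustrates the calling convention rather than the implementation.

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/lisafs"
)

// syncOne records the FDID when acc is non-nil so the caller can later flush
// everything in one batched call (fs.clientLisa.SyncFDs in the diff), and
// otherwise issues the per-FD fsync immediately.
func syncOne(ctx context.Context, id lisafs.FDID, acc *[]lisafs.FDID, doFsync func(context.Context, lisafs.FDID) error) error {
	if acc != nil {
		*acc = append(*acc, id)
		return nil
	}
	return doFsync(ctx, id)
}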
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 43440ec19..7bef8242f 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -48,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
@@ -83,6 +84,7 @@ const (
moptForcePageCache = "force_page_cache"
moptLimitHostFDTranslation = "limit_host_fd_translation"
moptOverlayfsStaleRead = "overlayfs_stale_read"
+ moptLisafs = "lisafs"
)
// Valid values for the "cache" mount option.
@@ -118,6 +120,10 @@ type filesystem struct {
// client is the client used by this filesystem. client is immutable.
client *p9.Client `state:"nosave"`
+	// clientLisa is the client used for communicating with the server when
+	// lisafs is enabled. clientLisa is immutable.
+ clientLisa *lisafs.Client `state:"nosave"`
+
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -161,6 +167,12 @@ type filesystem struct {
inoMu sync.Mutex `state:"nosave"`
inoByQIDPath map[uint64]uint64 `state:"nosave"`
+ // inoByKey is the same as inoByQIDPath but only used by lisafs. It helps
+ // identify inodes based on the device ID and host inode number provided
+ // by the gofer process. It is not preserved across checkpoint/restore for
+ // the same reason as above. inoByKey is protected by inoMu.
+ inoByKey map[inoKey]uint64 `state:"nosave"`
+
// lastIno is the last inode number assigned to a file. lastIno is accessed
// using atomic memory operations.
lastIno uint64
@@ -214,6 +226,10 @@ type filesystemOptions struct {
// way that application FDs representing "special files" such as sockets
// do. Note that this disables client caching and mmap for regular files.
regularFilesUseSpecialFileFD bool
+
+ // lisaEnabled indicates whether the client will use lisafs protocol to
+ // communicate with the server instead of 9P.
+ lisaEnabled bool
}
// InteropMode controls the client's interaction with other remote filesystem
@@ -427,6 +443,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
delete(mopts, moptOverlayfsStaleRead)
fsopts.overlayfsStaleRead = true
}
+ if lisafs, ok := mopts[moptLisafs]; ok {
+ delete(mopts, moptLisafs)
+ fsopts.lisaEnabled, err = strconv.ParseBool(lisafs)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid lisafs option: %s", lisafs)
+ return nil, nil, linuxerr.EINVAL
+ }
+ }
// fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
// "cache=none".
@@ -458,44 +482,83 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
syncableDentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
inoByQIDPath: make(map[uint64]uint64),
+ inoByKey: make(map[inoKey]uint64),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
+ if err := fs.initClientAndRoot(ctx); err != nil {
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+
+ return &fs.vfsfs, &fs.root.vfsd, nil
+}
+
+func (fs *filesystem) initClientAndRoot(ctx context.Context) error {
+ var err error
+ if fs.opts.lisaEnabled {
+ var rootInode *lisafs.Inode
+ rootInode, err = fs.initClientLisa(ctx)
+ if err != nil {
+ return err
+ }
+ fs.root, err = fs.newDentryLisa(ctx, rootInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, rootInode.ControlFD)
+ }
+ } else {
+ fs.root, err = fs.initClient(ctx)
+ }
+
+ // Set the root's reference count to 2. One reference is returned to the
+ // caller, and the other is held by fs to prevent the root from being "cached"
+ // and subsequently evicted.
+ if err == nil {
+ fs.root.refs = 2
+ }
+ return err
+}
+
+func (fs *filesystem) initClientLisa(ctx context.Context) (*lisafs.Inode, error) {
+ sock, err := unet.NewSocket(fs.opts.fd)
+ if err != nil {
+ return nil, err
+ }
+
+ var rootInode *lisafs.Inode
+ ctx.UninterruptibleSleepStart(false)
+ fs.clientLisa, rootInode, err = lisafs.NewClient(sock, fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ return rootInode, err
+}
+
+func (fs *filesystem) initClient(ctx context.Context) (*dentry, error) {
// Connect to the server.
if err := fs.dial(ctx); err != nil {
- return nil, nil, err
+ return nil, err
}
// Perform attach to obtain the filesystem root.
ctx.UninterruptibleSleepStart(false)
- attached, err := fs.client.Attach(fsopts.aname)
+ attached, err := fs.client.Attach(fs.opts.aname)
ctx.UninterruptibleSleepFinish(false)
if err != nil {
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
attachFile := p9file{attached}
qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
- // Set the root's reference count to 2. One reference is returned to the
- // caller, and the other is held by fs to prevent the root from being "cached"
- // and subsequently evicted.
- root.refs = 2
- fs.root = root
-
- return &fs.vfsfs, &root.vfsd, nil
+ return root, nil
}
func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
@@ -613,7 +676,11 @@ func (fs *filesystem) Release(ctx context.Context) {
if !fs.iopts.LeakConnection {
// Close the connection to the server. This implicitly clunks all fids.
- fs.client.Close()
+ if fs.opts.lisaEnabled {
+ fs.clientLisa.Close()
+ } else {
+ fs.client.Close()
+ }
}
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
@@ -644,6 +711,23 @@ func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
}
}
+// inoKey is the key used to identify the inode backed by this dentry.
+//
+// +stateify savable
+type inoKey struct {
+ ino uint64
+ devMinor uint32
+ devMajor uint32
+}
+
+func inoKeyFromStat(stat *linux.Statx) inoKey {
+ return inoKey{
+ ino: stat.Ino,
+ devMinor: stat.DevMinor,
+ devMajor: stat.DevMajor,
+ }
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -674,6 +758,9 @@ type dentry struct {
// qidPath is the p9.QID.Path for this file. qidPath is immutable.
qidPath uint64
+ // inoKey is used to identify this dentry's inode.
+ inoKey inoKey
+
// file is the unopened p9.File that backs this dentry. file is immutable.
//
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
@@ -681,6 +768,14 @@ type dentry struct {
// only files that can be synthetic are sockets, pipes, and directories.
file p9file `state:"nosave"`
+ // controlFDLisa is used by lisafs to perform path based operations on this
+ // dentry.
+ //
+	// If !controlFDLisa.Ok(), this dentry represents a synthetic file, i.e. a
+ // file that does not exist on the remote filesystem. As of this writing, the
+ // only files that can be synthetic are sockets, pipes, and directories.
+ controlFDLisa lisafs.ClientFD `state:"nosave"`
+
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
deleted uint32
@@ -791,12 +886,14 @@ type dentry struct {
// always either -1 or equal to readFD; if !writeFile.isNil() (the file has
// been opened for writing), it is additionally either -1 or equal to
// writeFD.
- handleMu sync.RWMutex `state:"nosave"`
- readFile p9file `state:"nosave"`
- writeFile p9file `state:"nosave"`
- readFD int32 `state:"nosave"`
- writeFD int32 `state:"nosave"`
- mmapFD int32 `state:"nosave"`
+ handleMu sync.RWMutex `state:"nosave"`
+ readFile p9file `state:"nosave"`
+ writeFile p9file `state:"nosave"`
+ readFDLisa lisafs.ClientFD `state:"nosave"`
+ writeFDLisa lisafs.ClientFD `state:"nosave"`
+ readFD int32 `state:"nosave"`
+ writeFD int32 `state:"nosave"`
+ mmapFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -920,6 +1017,79 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
return d, nil
}
+func (fs *filesystem) newDentryLisa(ctx context.Context, ino *lisafs.Inode) (*dentry, error) {
+ if ino.Stat.Mask&linux.STATX_TYPE == 0 {
+ ctx.Warningf("can't create gofer.dentry without file type")
+ return nil, linuxerr.EIO
+ }
+ if ino.Stat.Mode&linux.FileTypeMask == linux.ModeRegular && ino.Stat.Mask&linux.STATX_SIZE == 0 {
+ ctx.Warningf("can't create regular file gofer.dentry without file size")
+ return nil, linuxerr.EIO
+ }
+
+ inoKey := inoKeyFromStat(&ino.Stat)
+ d := &dentry{
+ fs: fs,
+ inoKey: inoKey,
+ ino: fs.inoFromKey(inoKey),
+ mode: uint32(ino.Stat.Mode),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
+ blockSize: hostarch.PageSize,
+ readFD: -1,
+ writeFD: -1,
+ mmapFD: -1,
+ controlFDLisa: fs.clientLisa.NewFD(ino.ControlFD),
+ }
+
+ d.pf.dentry = d
+ if ino.Stat.Mask&linux.STATX_UID != 0 {
+ d.uid = dentryUIDFromLisaUID(lisafs.UID(ino.Stat.UID))
+ }
+ if ino.Stat.Mask&linux.STATX_GID != 0 {
+ d.gid = dentryGIDFromLisaGID(lisafs.GID(ino.Stat.GID))
+ }
+ if ino.Stat.Mask&linux.STATX_SIZE != 0 {
+ d.size = ino.Stat.Size
+ }
+ if ino.Stat.Blksize != 0 {
+ d.blockSize = ino.Stat.Blksize
+ }
+ if ino.Stat.Mask&linux.STATX_ATIME != 0 {
+ d.atime = dentryTimestampFromLisa(ino.Stat.Atime)
+ }
+ if ino.Stat.Mask&linux.STATX_MTIME != 0 {
+ d.mtime = dentryTimestampFromLisa(ino.Stat.Mtime)
+ }
+ if ino.Stat.Mask&linux.STATX_CTIME != 0 {
+ d.ctime = dentryTimestampFromLisa(ino.Stat.Ctime)
+ }
+ if ino.Stat.Mask&linux.STATX_BTIME != 0 {
+ d.btime = dentryTimestampFromLisa(ino.Stat.Btime)
+ }
+ if ino.Stat.Mask&linux.STATX_NLINK != 0 {
+ d.nlink = ino.Stat.Nlink
+ }
+ d.vfsd.Init(d)
+ refsvfs2.Register(d)
+ fs.syncMu.Lock()
+ fs.syncableDentries[d] = struct{}{}
+ fs.syncMu.Unlock()
+ return d, nil
+}
+
+func (fs *filesystem) inoFromKey(key inoKey) uint64 {
+ fs.inoMu.Lock()
+ defer fs.inoMu.Unlock()
+
+ if ino, ok := fs.inoByKey[key]; ok {
+ return ino
+ }
+ ino := fs.nextIno()
+ fs.inoByKey[key] = ino
+ return ino
+}
+
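A note on the mapping added above: inoFromKey deduplicates sentry inode numbers by the (ino, devMinor, devMajor) triple reported by the gofer, so repeated lookups of the same remote file reuse one inode number. A small usage sketch, assuming it sits in the gofer package next to the code above (sameInode is hypothetical):

// sameInode reports whether two Statx results from the gofer map to the same
// sentry inode number under inoFromKey.
func sameInode(fs *filesystem, a, b *linux.Statx) bool {
	return fs.inoFromKey(inoKeyFromStat(a)) == fs.inoFromKey(inoKeyFromStat(b))
}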
func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 {
fs.inoMu.Lock()
defer fs.inoMu.Unlock()
@@ -936,7 +1106,7 @@ func (fs *filesystem) nextIno() uint64 {
}
func (d *dentry) isSynthetic() bool {
- return d.file.isNil()
+ return !d.isControlFileOk()
}
func (d *dentry) cachedMetadataAuthoritative() bool {
@@ -986,6 +1156,50 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
}
+// updateFromLisaStatLocked is called to update d's metadata after an update
+// from the remote filesystem.
+// Precondition: d.metadataMu must be locked.
+// +checklocks:d.metadataMu
+func (d *dentry) updateFromLisaStatLocked(stat *linux.Statx) {
+ if stat.Mask&linux.STATX_TYPE != 0 {
+ if got, want := stat.Mode&linux.FileTypeMask, d.fileType(); uint32(got) != want {
+ panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
+ }
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, uint32(stat.Mode))
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, dentryUIDFromLisaUID(lisafs.UID(stat.UID)))
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+		atomic.StoreUint32(&d.gid, dentryGIDFromLisaGID(lisafs.GID(stat.GID)))
+ }
+ if stat.Blksize != 0 {
+ atomic.StoreUint32(&d.blockSize, stat.Blksize)
+ }
+ // Don't override newer client-defined timestamps with old server-defined
+ // ones.
+ if stat.Mask&linux.STATX_ATIME != 0 && atomic.LoadUint32(&d.atimeDirty) == 0 {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromLisa(stat.Atime))
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 && atomic.LoadUint32(&d.mtimeDirty) == 0 {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromLisa(stat.Mtime))
+ }
+ if stat.Mask&linux.STATX_CTIME != 0 {
+ atomic.StoreInt64(&d.ctime, dentryTimestampFromLisa(stat.Ctime))
+ }
+ if stat.Mask&linux.STATX_BTIME != 0 {
+ atomic.StoreInt64(&d.btime, dentryTimestampFromLisa(stat.Btime))
+ }
+ if stat.Mask&linux.STATX_NLINK != 0 {
+ atomic.StoreUint32(&d.nlink, stat.Nlink)
+ }
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.updateSizeLocked(stat.Size)
+ }
+}
+
// Preconditions: !d.isSynthetic().
// Preconditions: d.metadataMu is locked.
// +checklocks:d.metadataMu
@@ -995,6 +1209,9 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error {
if d.writeFD < 0 {
d.handleMu.RUnlock()
// Ask the gofer if we don't have a host FD.
+ if d.fs.opts.lisaEnabled {
+ return d.updateFromStatLisaLocked(ctx, nil)
+ }
return d.updateFromGetattrLocked(ctx, p9file{})
}
@@ -1014,6 +1231,9 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// updating stale attributes in d.updateFromP9AttrsLocked().
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
+ if d.fs.opts.lisaEnabled {
+ return d.updateFromStatLisaLocked(ctx, nil)
+ }
return d.updateFromGetattrLocked(ctx, p9file{})
}
@@ -1021,6 +1241,45 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// * !d.isSynthetic().
// * d.metadataMu is locked.
// +checklocks:d.metadataMu
+func (d *dentry) updateFromStatLisaLocked(ctx context.Context, fdLisa *lisafs.ClientFD) error {
+ handleMuRLocked := false
+ if fdLisa == nil {
+		// Use open FDs in preference to the control FD. This may be significantly
+ // more efficient in some implementations. Prefer a writable FD over a
+ // readable one since some filesystem implementations may update a writable
+ // FD's metadata after writes, without making metadata updates immediately
+ // visible to read-only FDs representing the same file.
+ d.handleMu.RLock()
+ switch {
+ case d.writeFDLisa.Ok():
+ fdLisa = &d.writeFDLisa
+ handleMuRLocked = true
+ case d.readFDLisa.Ok():
+ fdLisa = &d.readFDLisa
+ handleMuRLocked = true
+ default:
+ fdLisa = &d.controlFDLisa
+ d.handleMu.RUnlock()
+ }
+ }
+
+ var stat linux.Statx
+ err := fdLisa.StatTo(ctx, &stat)
+ if handleMuRLocked {
+ // handleMu must be released before updateFromLisaStatLocked().
+ d.handleMu.RUnlock() // +checklocksforce: complex case.
+ }
+ if err != nil {
+ return err
+ }
+ d.updateFromLisaStatLocked(&stat)
+ return nil
+}
+
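The comment in updateFromStatLisaLocked above spells out the FD preference order for Stat: a writable FD first, then a readable FD, then the control FD. A standalone sketch of just that selection, assuming only lisafs.ClientFD.Ok() and the same package as the surrounding code (pickStatFD is illustrative, not part of the change):

// pickStatFD returns the FD that should service a Stat, preferring a
// writable FD, then a readable one, then the control FD.
func pickStatFD(writeFD, readFD, controlFD *lisafs.ClientFD) *lisafs.ClientFD {
	switch {
	case writeFD.Ok():
		return writeFD
	case readFD.Ok():
		return readFD
	default:
		return controlFD
	}
}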
+// Preconditions:
+// * !d.isSynthetic().
+// * d.metadataMu is locked.
+// +checklocks:d.metadataMu
func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error {
handleMuRLocked := false
if file.isNil() {
@@ -1160,6 +1419,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
}
}
+	// failureMask indicates which attributes could not be set on the remote
+	// filesystem. p9 returns a single error if any attribute fails to be set,
+	// which leads to inconsistency: the server may have applied some attributes
+	// successfully, yet that error prevents the successful ones from being
+	// updated in the dentry cache. lisafs instead reports per-attribute failures.
+ var failureMask uint32
+ var failureErr error
if !d.isSynthetic() {
if stat.Mask != 0 {
if stat.Mask&linux.STATX_SIZE != 0 {
@@ -1169,35 +1435,50 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// the remote file has been truncated).
d.dataMu.Lock()
}
- if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0,
- UID: stat.Mask&linux.STATX_UID != 0,
- GID: stat.Mask&linux.STATX_GID != 0,
- Size: stat.Mask&linux.STATX_SIZE != 0,
- ATime: stat.Mask&linux.STATX_ATIME != 0,
- MTime: stat.Mask&linux.STATX_MTIME != 0,
- ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
- MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
- }, p9.SetAttr{
- Permissions: p9.FileMode(stat.Mode),
- UID: p9.UID(stat.UID),
- GID: p9.GID(stat.GID),
- Size: stat.Size,
- ATimeSeconds: uint64(stat.Atime.Sec),
- ATimeNanoSeconds: uint64(stat.Atime.Nsec),
- MTimeSeconds: uint64(stat.Mtime.Sec),
- MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
- }); err != nil {
- if stat.Mask&linux.STATX_SIZE != 0 {
- d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ if d.fs.opts.lisaEnabled {
+ var err error
+ failureMask, failureErr, err = d.controlFDLisa.SetStat(ctx, stat)
+ if err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
+ return err
+ }
+ } else {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
+ return err
}
- return err
}
if stat.Mask&linux.STATX_SIZE != 0 {
- // d.size should be kept up to date, and privatized
- // copy-on-write mappings of truncated pages need to be
- // invalidated, even if InteropModeShared is in effect.
- d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
+ if failureMask&linux.STATX_SIZE == 0 {
+ // d.size should be kept up to date, and privatized
+ // copy-on-write mappings of truncated pages need to be
+ // invalidated, even if InteropModeShared is in effect.
+ d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
+ } else {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
}
}
if d.fs.opts.interop == InteropModeShared {
@@ -1208,13 +1489,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 {
+ if stat.Mask&linux.STATX_MODE != 0 && failureMask&linux.STATX_MODE == 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
- if stat.Mask&linux.STATX_UID != 0 {
+ if stat.Mask&linux.STATX_UID != 0 && failureMask&linux.STATX_UID == 0 {
atomic.StoreUint32(&d.uid, stat.UID)
}
- if stat.Mask&linux.STATX_GID != 0 {
+ if stat.Mask&linux.STATX_GID != 0 && failureMask&linux.STATX_GID == 0 {
atomic.StoreUint32(&d.gid, stat.GID)
}
// Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because
@@ -1222,15 +1503,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// stat.Mtime to client-local timestamps above, and if
// !d.cachedMetadataAuthoritative() then we returned after calling
// d.file.setAttr(). For the same reason, now must have been initialized.
- if stat.Mask&linux.STATX_ATIME != 0 {
+ if stat.Mask&linux.STATX_ATIME != 0 && failureMask&linux.STATX_ATIME == 0 {
atomic.StoreInt64(&d.atime, stat.Atime.ToNsec())
atomic.StoreUint32(&d.atimeDirty, 0)
}
- if stat.Mask&linux.STATX_MTIME != 0 {
+ if stat.Mask&linux.STATX_MTIME != 0 && failureMask&linux.STATX_MTIME == 0 {
atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec())
atomic.StoreUint32(&d.mtimeDirty, 0)
}
atomic.StoreInt64(&d.ctime, now)
+ if failureMask != 0 {
+ // Setting some attribute failed on the remote filesystem.
+ return failureErr
+ }
return nil
}
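The failureMask handling above applies each attribute to the dentry cache only if it was both requested and not reported as failed by the server. A compact sketch of that gate, assuming the same package as the surrounding code (applyIfSet is illustrative, not part of the change):

// applyIfSet runs apply only when the attribute bit is requested in reqMask
// and the remote filesystem did not report a failure for it in failureMask.
func applyIfSet(reqMask, failureMask, attrBit uint32, apply func()) {
	if reqMask&attrBit != 0 && failureMask&attrBit == 0 {
		apply()
	}
}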
@@ -1346,6 +1631,20 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 {
return uint32(gid)
}
+func dentryUIDFromLisaUID(uid lisafs.UID) uint32 {
+ if !uid.Ok() {
+ return uint32(auth.OverflowUID)
+ }
+ return uint32(uid)
+}
+
+func dentryGIDFromLisaGID(gid lisafs.GID) uint32 {
+ if !gid.Ok() {
+ return uint32(auth.OverflowGID)
+ }
+ return uint32(gid)
+}
+
// IncRef implements vfs.DentryImpl.IncRef.
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
@@ -1654,15 +1953,24 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.dirty.RemoveAll()
}
d.dataMu.Unlock()
- // Clunk open fids and close open host FDs.
- if !d.readFile.isNil() {
- _ = d.readFile.close(ctx)
- }
- if !d.writeFile.isNil() && d.readFile != d.writeFile {
- _ = d.writeFile.close(ctx)
+ if d.fs.opts.lisaEnabled {
+ if d.readFDLisa.Ok() && d.readFDLisa.ID() != d.writeFDLisa.ID() {
+ d.readFDLisa.CloseBatched(ctx)
+ }
+ if d.writeFDLisa.Ok() {
+ d.writeFDLisa.CloseBatched(ctx)
+ }
+ } else {
+ // Clunk open fids and close open host FDs.
+ if !d.readFile.isNil() {
+ _ = d.readFile.close(ctx)
+ }
+ if !d.writeFile.isNil() && d.readFile != d.writeFile {
+ _ = d.writeFile.close(ctx)
+ }
+ d.readFile = p9file{}
+ d.writeFile = p9file{}
}
- d.readFile = p9file{}
- d.writeFile = p9file{}
if d.readFD >= 0 {
_ = unix.Close(int(d.readFD))
}
@@ -1674,7 +1982,7 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.mmapFD = -1
d.handleMu.Unlock()
- if !d.file.isNil() {
+ if d.isControlFileOk() {
// Note that it's possible that d.atimeDirty or d.mtimeDirty are true,
// i.e. client and server timestamps may differ (because e.g. a client
// write was serviced by the page cache, and only written back to the
@@ -1683,10 +1991,16 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// instantiated for the same file would remain coherent. Unfortunately,
// this turns out to be too expensive in many cases, so for now we
// don't do this.
- if err := d.file.close(ctx); err != nil {
- log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
+
+ // Close the control FD.
+ if d.fs.opts.lisaEnabled {
+ d.controlFDLisa.CloseBatched(ctx)
+ } else {
+ if err := d.file.close(ctx); err != nil {
+ log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
+ }
+ d.file = p9file{}
}
- d.file = p9file{}
// Remove d from the set of syncable dentries.
d.fs.syncMu.Lock()
@@ -1712,10 +2026,38 @@ func (d *dentry) setDeleted() {
atomic.StoreUint32(&d.deleted, 1)
}
-func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
- if d.file.isNil() {
+func (d *dentry) isControlFileOk() bool {
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.Ok()
+ }
+ return !d.file.isNil()
+}
+
+func (d *dentry) isReadFileOk() bool {
+ if d.fs.opts.lisaEnabled {
+ return d.readFDLisa.Ok()
+ }
+ return !d.readFile.isNil()
+}
+
+func (d *dentry) listXattr(ctx context.Context, size uint64) ([]string, error) {
+ if !d.isControlFileOk() {
return nil, nil
}
+
+ if d.fs.opts.lisaEnabled {
+ xattrs, err := d.controlFDLisa.ListXattr(ctx, size)
+ if err != nil {
+ return nil, err
+ }
+
+ res := make([]string, 0, len(xattrs))
+ for _, xattr := range xattrs {
+ res = append(res, xattr)
+ }
+ return res, nil
+ }
+
xattrMap, err := d.file.listXattr(ctx, size)
if err != nil {
return nil, err
@@ -1728,32 +2070,41 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui
}
func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return "", linuxerr.ENODATA
}
if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
return "", err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.GetXattr(ctx, opts.Name, opts.Size)
+ }
return d.file.getXattr(ctx, opts.Name, opts.Size)
}
func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return linuxerr.EPERM
}
if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
return err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.SetXattr(ctx, opts.Name, opts.Value, opts.Flags)
+ }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return linuxerr.EPERM
}
if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
return err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.RemoveXattr(ctx, name)
+ }
return d.file.removeXattr(ctx, name)
}
@@ -1765,19 +2116,30 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// O_TRUNC).
if !trunc {
d.handleMu.RLock()
- if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) {
+ var canReuseCurHandle bool
+ if d.fs.opts.lisaEnabled {
+ canReuseCurHandle = (!read || d.readFDLisa.Ok()) && (!write || d.writeFDLisa.Ok())
+ } else {
+ canReuseCurHandle = (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil())
+ }
+ d.handleMu.RUnlock()
+ if canReuseCurHandle {
// Current handles are sufficient.
- d.handleMu.RUnlock()
return nil
}
- d.handleMu.RUnlock()
}
var fdsToCloseArr [2]int32
fdsToClose := fdsToCloseArr[:0]
invalidateTranslations := false
d.handleMu.Lock()
- if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc {
+ var needNewHandle bool
+ if d.fs.opts.lisaEnabled {
+ needNewHandle = (read && !d.readFDLisa.Ok()) || (write && !d.writeFDLisa.Ok()) || trunc
+ } else {
+ needNewHandle = (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc
+ }
+ if needNewHandle {
// Get a new handle. If this file has been opened for both reading and
// writing, try to get a single handle that is usable for both:
//
@@ -1786,9 +2148,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
//
// - NOTE(b/141991141): Some filesystems may not ensure coherence
// between multiple handles for the same file.
- openReadable := !d.readFile.isNil() || read
- openWritable := !d.writeFile.isNil() || write
- h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ var (
+ openReadable bool
+ openWritable bool
+ h handle
+ err error
+ )
+ if d.fs.opts.lisaEnabled {
+ openReadable = d.readFDLisa.Ok() || read
+ openWritable = d.writeFDLisa.Ok() || write
+ h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc)
+ } else {
+ openReadable = !d.readFile.isNil() || read
+ openWritable = !d.writeFile.isNil() || write
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) {
// It may not be possible to use a single handle for both
// reading and writing, since permissions on the file may have
@@ -1798,7 +2172,11 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d)
openReadable = read
openWritable = write
- h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc)
+ } else {
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
}
if err != nil {
d.handleMu.Unlock()
@@ -1860,9 +2238,16 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// previously opened for reading (without an FD), then existing
// translations of the file may use the internal page cache;
// invalidate those mappings.
- if d.writeFile.isNil() {
- invalidateTranslations = !d.readFile.isNil()
- atomic.StoreInt32(&d.mmapFD, h.fd)
+ if d.fs.opts.lisaEnabled {
+ if !d.writeFDLisa.Ok() {
+ invalidateTranslations = d.readFDLisa.Ok()
+ atomic.StoreInt32(&d.mmapFD, h.fd)
+ }
+ } else {
+ if d.writeFile.isNil() {
+ invalidateTranslations = !d.readFile.isNil()
+ atomic.StoreInt32(&d.mmapFD, h.fd)
+ }
}
} else if openWritable && d.writeFD < 0 {
atomic.StoreInt32(&d.writeFD, h.fd)
@@ -1889,24 +2274,45 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
atomic.StoreInt32(&d.mmapFD, -1)
}
- // Switch to new fids.
- var oldReadFile p9file
- if openReadable {
- oldReadFile = d.readFile
- d.readFile = h.file
- }
- var oldWriteFile p9file
- if openWritable {
- oldWriteFile = d.writeFile
- d.writeFile = h.file
- }
- // NOTE(b/141991141): Clunk old fids before making new fids visible (by
- // unlocking d.handleMu).
- if !oldReadFile.isNil() {
- oldReadFile.close(ctx)
- }
- if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
- oldWriteFile.close(ctx)
+ // Switch to new fids/FDs.
+ if d.fs.opts.lisaEnabled {
+ oldReadFD := lisafs.InvalidFDID
+ if openReadable {
+ oldReadFD = d.readFDLisa.ID()
+ d.readFDLisa = h.fdLisa
+ }
+ oldWriteFD := lisafs.InvalidFDID
+ if openWritable {
+ oldWriteFD = d.writeFDLisa.ID()
+ d.writeFDLisa = h.fdLisa
+ }
+ // NOTE(b/141991141): Close old FDs before making new FDs visible (by
+ // unlocking d.handleMu).
+ if oldReadFD.Ok() {
+ d.fs.clientLisa.CloseFDBatched(ctx, oldReadFD)
+ }
+ if oldWriteFD.Ok() && oldReadFD != oldWriteFD {
+ d.fs.clientLisa.CloseFDBatched(ctx, oldWriteFD)
+ }
+ } else {
+ var oldReadFile p9file
+ if openReadable {
+ oldReadFile = d.readFile
+ d.readFile = h.file
+ }
+ var oldWriteFile p9file
+ if openWritable {
+ oldWriteFile = d.writeFile
+ d.writeFile = h.file
+ }
+ // NOTE(b/141991141): Clunk old fids before making new fids visible (by
+ // unlocking d.handleMu).
+ if !oldReadFile.isNil() {
+ oldReadFile.close(ctx)
+ }
+ if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
+ oldWriteFile.close(ctx)
+ }
}
}
d.handleMu.Unlock()
@@ -1930,27 +2336,29 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// Preconditions: d.handleMu must be locked.
func (d *dentry) readHandleLocked() handle {
return handle{
- file: d.readFile,
- fd: d.readFD,
+ fdLisa: d.readFDLisa,
+ file: d.readFile,
+ fd: d.readFD,
}
}
// Preconditions: d.handleMu must be locked.
func (d *dentry) writeHandleLocked() handle {
return handle{
- file: d.writeFile,
- fd: d.writeFD,
+ fdLisa: d.writeFDLisa,
+ file: d.writeFile,
+ fd: d.writeFD,
}
}
func (d *dentry) syncRemoteFile(ctx context.Context) error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
- return d.syncRemoteFileLocked(ctx)
+ return d.syncRemoteFileLocked(ctx, nil /* accFsyncFDIDsLisa */)
}
// Preconditions: d.handleMu must be locked.
-func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
+func (d *dentry) syncRemoteFileLocked(ctx context.Context, accFsyncFDIDsLisa *[]lisafs.FDID) error {
// If we have a host FD, fsyncing it is likely to be faster than an fsync
// RPC. Prefer syncing write handles over read handles, since some remote
// filesystem implementations may not sync changes made through write
@@ -1961,7 +2369,13 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
ctx.UninterruptibleSleepFinish(false)
return err
}
- if !d.writeFile.isNil() {
+ if d.fs.opts.lisaEnabled && d.writeFDLisa.Ok() {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.writeFDLisa.ID())
+ return nil
+ }
+ return d.writeFDLisa.Sync(ctx)
+ } else if !d.fs.opts.lisaEnabled && !d.writeFile.isNil() {
return d.writeFile.fsync(ctx)
}
if d.readFD >= 0 {
@@ -1970,13 +2384,19 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
ctx.UninterruptibleSleepFinish(false)
return err
}
- if !d.readFile.isNil() {
+ if d.fs.opts.lisaEnabled && d.readFDLisa.Ok() {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.readFDLisa.ID())
+ return nil
+ }
+ return d.readFDLisa.Sync(ctx)
+ } else if !d.fs.opts.lisaEnabled && !d.readFile.isNil() {
return d.readFile.fsync(ctx)
}
return nil
}
-func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error {
+func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
h := d.writeHandleLocked()
@@ -1989,7 +2409,7 @@ func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) err
return err
}
}
- if err := d.syncRemoteFileLocked(ctx); err != nil {
+ if err := d.syncRemoteFileLocked(ctx, accFsyncFDIDsLisa); err != nil {
if !forFilesystemSync {
return err
}
@@ -2046,18 +2466,33 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
d := fd.dentry()
const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
- // Use specialFileFD.handle.file for the getattr if available, for the
- // same reason that we try to use open file handles in
- // dentry.updateFromGetattrLocked().
- var file p9file
- if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
- file = sffd.handle.file
- }
- d.metadataMu.Lock()
- err := d.updateFromGetattrLocked(ctx, file)
- d.metadataMu.Unlock()
- if err != nil {
- return linux.Statx{}, err
+ if d.fs.opts.lisaEnabled {
+ // Use specialFileFD.handle.fdLisa for the Stat if available, for the
+ // same reason that we try to use open FD in updateFromStatLisaLocked().
+ var fdLisa *lisafs.ClientFD
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ fdLisa = &sffd.handle.fdLisa
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromStatLisaLocked(ctx, fdLisa)
+ d.metadataMu.Unlock()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ } else {
+ // Use specialFileFD.handle.file for the getattr if available, for the
+ // same reason that we try to use open file handles in
+ // dentry.updateFromGetattrLocked().
+ var file p9file
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ file = sffd.handle.file
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromGetattrLocked(ctx, file)
+ d.metadataMu.Unlock()
+ if err != nil {
+ return linux.Statx{}, err
+ }
}
}
var stat linux.Statx
@@ -2078,7 +2513,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size)
+ return fd.dentry().listXattr(ctx, size)
}
// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
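Aside: the accFsyncFDIDsLisa parameter threaded through syncRemoteFileLocked() and syncCachedFile() above lets a filesystem-wide sync accumulate lisafs FDIDs and flush them in a single RPC instead of issuing one fsync per dentry. A rough sketch of the intended calling pattern (illustrative only; the batched-sync helper name on the lisafs client is an assumption, and the real caller lives in filesystem.go, outside this excerpt):

	func syncDentriesLisa(ctx context.Context, fs *filesystem, ds []*dentry) error {
		var fdsToSync []lisafs.FDID
		for _, d := range ds {
			// forFilesystemSync: log and continue rather than aborting the sync.
			if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, &fdsToSync); err != nil {
				ctx.Infof("gofer: failed to sync dentry %p: %v", d, err)
			}
		}
		// One RPC covers every accumulated FD (SyncFDs is assumed here).
		return fs.clientLisa.SyncFDs(ctx, fdsToSync)
	}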
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index 806392d50..d5cc73f33 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -33,6 +33,7 @@ func TestDestroyIdempotent(t *testing.T) {
},
syncableDentries: make(map[*dentry]struct{}),
inoByQIDPath: make(map[uint64]uint64),
+ inoByKey: make(map[inoKey]uint64),
}
attr := &p9.Attr{
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 02540a754..394aecd62 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -17,6 +17,7 @@ package gofer
import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
@@ -26,10 +27,13 @@ import (
// handle represents a remote "open file descriptor", consisting of an opened
// fid (p9.File) and optionally a host file descriptor.
//
+// If lisafs is being used, fdLisa points to an open file on the server.
+//
// These are explicitly not savable.
type handle struct {
- file p9file
- fd int32 // -1 if unavailable
+ fdLisa lisafs.ClientFD
+ file p9file
+ fd int32 // -1 if unavailable
}
// Preconditions: read || write.
@@ -65,13 +69,47 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand
}, nil
}
+// Preconditions: read || write.
+func openHandleLisa(ctx context.Context, fdLisa lisafs.ClientFD, read, write, trunc bool) (handle, error) {
+ var flags uint32
+ switch {
+ case read && write:
+ flags = unix.O_RDWR
+ case read:
+ flags = unix.O_RDONLY
+ case write:
+ flags = unix.O_WRONLY
+ default:
+ panic("tried to open unreadable and unwritable handle")
+ }
+ if trunc {
+ flags |= unix.O_TRUNC
+ }
+ openFD, hostFD, err := fdLisa.OpenAt(ctx, flags)
+ if err != nil {
+ return handle{fd: -1}, err
+ }
+ h := handle{
+ fdLisa: fdLisa.Client().NewFD(openFD),
+ fd: int32(hostFD),
+ }
+ return h, nil
+}
+
func (h *handle) isOpen() bool {
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Ok()
+ }
return !h.file.isNil()
}
func (h *handle) close(ctx context.Context) {
- h.file.close(ctx)
- h.file = p9file{}
+ if h.fdLisa.Client() != nil {
+ h.fdLisa.CloseBatched(ctx)
+ } else {
+ h.file.close(ctx)
+ h.file = p9file{}
+ }
if h.fd >= 0 {
unix.Close(int(h.fd))
h.fd = -1
@@ -89,19 +127,27 @@ func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offs
return n, err
}
if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() {
- n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
- return uint64(n), err
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Read(ctx, dsts.Head().ToSlice(), offset)
+ }
+ return h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
}
// Buffer the read since p9.File.ReadAt() takes []byte.
buf := make([]byte, dsts.NumBytes())
- n, err := h.file.readAt(ctx, buf, offset)
+ var n uint64
+ var err error
+ if h.fdLisa.Client() != nil {
+ n, err = h.fdLisa.Read(ctx, buf, offset)
+ } else {
+ n, err = h.file.readAt(ctx, buf, offset)
+ }
if n == 0 {
return 0, err
}
if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil {
return cp, cperr
}
- return uint64(n), err
+ return n, err
}
func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
@@ -115,8 +161,10 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
return n, err
}
if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() {
- n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
- return uint64(n), err
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Write(ctx, srcs.Head().ToSlice(), offset)
+ }
+ return h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
}
// Buffer the write since p9.File.WriteAt() takes []byte.
buf := make([]byte, srcs.NumBytes())
@@ -124,12 +172,18 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
if cp == 0 {
return 0, cperr
}
- n, err := h.file.writeAt(ctx, buf[:cp], offset)
+ var n uint64
+ var err error
+ if h.fdLisa.Client() != nil {
+ n, err = h.fdLisa.Write(ctx, buf[:cp], offset)
+ } else {
+ n, err = h.file.writeAt(ctx, buf[:cp], offset)
+ }
// err takes precedence over cperr.
if err != nil {
- return uint64(n), err
+ return n, err
}
- return uint64(n), cperr
+ return n, cperr
}
type handleReadWriter struct {
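Aside: a usage sketch of the dual-backend handle above (illustrative, not part of this change). Once a dentry's lisafs control FD is available, a caller can open, read through, and close a handle without caring which backend is active, because readToBlocksAt() and close() dispatch on fdLisa internally.

	func readPrefix(ctx context.Context, d *dentry) ([]byte, error) {
		h, err := openHandleLisa(ctx, d.controlFDLisa, true /* read */, false /* write */, false /* trunc */)
		if err != nil {
			return nil, err
		}
		defer h.close(ctx)
		// Read up to 4 KiB from the start of the file through the handle.
		buf := make([]byte, 4096)
		n, err := h.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), 0)
		return buf[:n], err
	}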
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 5a3ddfc9d..0d97b60fd 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -141,18 +141,18 @@ func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, u
return fdobj, qid, iounit, err
}
-func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (uint64, error) {
ctx.UninterruptibleSleepStart(false)
n, err := f.file.ReadAt(p, offset)
ctx.UninterruptibleSleepFinish(false)
- return n, err
+ return uint64(n), err
}
-func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (uint64, error) {
ctx.UninterruptibleSleepStart(false)
n, err := f.file.WriteAt(p, offset)
ctx.UninterruptibleSleepFinish(false)
- return n, err
+ return uint64(n), err
}
func (f p9file) fsync(ctx context.Context) error {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 947dbe05f..874f9873d 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -98,6 +98,12 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error {
}
d.handleMu.RLock()
defer d.handleMu.RUnlock()
+ if d.fs.opts.lisaEnabled {
+ if !d.writeFDLisa.Ok() {
+ return nil
+ }
+ return d.writeFDLisa.Flush(ctx)
+ }
if d.writeFile.isNil() {
return nil
}
@@ -110,6 +116,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
return d.doAllocate(ctx, offset, length, func() error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
+ if d.fs.opts.lisaEnabled {
+ return d.writeFDLisa.Allocate(ctx, mode, offset, length)
+ }
return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
})
}
@@ -282,8 +291,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
// changes to the host.
if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode {
atomic.StoreUint32(&d.mode, newMode)
- if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
- return 0, offset, err
+ if d.fs.opts.lisaEnabled {
+ stat := linux.Statx{Mask: linux.STATX_MODE, Mode: uint16(newMode)}
+ failureMask, failureErr, err := d.controlFDLisa.SetStat(ctx, &stat)
+ if err != nil {
+ return 0, offset, err
+ }
+ if failureMask != 0 {
+ return 0, offset, failureErr
+ }
+ } else {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
+ return 0, offset, err
+ }
}
}
}
@@ -677,7 +697,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *regularFileFD) Sync(ctx context.Context) error {
- return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */)
+ return fd.dentry().syncCachedFile(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go
index 226790a11..5d4009832 100644
--- a/pkg/sentry/fsimpl/gofer/revalidate.go
+++ b/pkg/sentry/fsimpl/gofer/revalidate.go
@@ -15,7 +15,9 @@
package gofer
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -234,28 +236,54 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// Lock metadata on all dentries *before* getting attributes for them.
state.lockAllMetadata()
- stats, err := state.start.file.multiGetAttr(ctx, state.names)
- if err != nil {
- return err
+
+ var (
+ stats []p9.FullStat
+ statsLisa []linux.Statx
+ numStats int
+ )
+ if fs.opts.lisaEnabled {
+ var err error
+ statsLisa, err = state.start.controlFDLisa.WalkStat(ctx, state.names)
+ if err != nil {
+ return err
+ }
+ numStats = len(statsLisa)
+ } else {
+ var err error
+ stats, err = state.start.file.multiGetAttr(ctx, state.names)
+ if err != nil {
+ return err
+ }
+ numStats = len(stats)
}
i := -1
for d := state.popFront(); d != nil; d = state.popFront() {
i++
- found := i < len(stats)
+ found := i < numStats
if i == 0 && len(state.names[0]) == 0 {
if found && !d.isSynthetic() {
// First dentry is where the search is starting, just update attributes
// since it cannot be replaced.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
+ if fs.opts.lisaEnabled {
+ d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: acquired by lockAllMetadata.
+ } else {
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
+ }
}
d.metadataMu.Unlock() // +checklocksforce: see above.
continue
}
- // Note that synthetic dentries will always fails the comparison check
- // below.
- if !found || d.qidPath != stats[i].QID.Path {
+ // Note that synthetic dentries will always fail this comparison check.
+ var shouldInvalidate bool
+ if fs.opts.lisaEnabled {
+ shouldInvalidate = !found || d.inoKey != inoKeyFromStat(&statsLisa[i])
+ } else {
+ shouldInvalidate = !found || d.qidPath != stats[i].QID.Path
+ }
+ if shouldInvalidate {
d.metadataMu.Unlock() // +checklocksforce: see above.
if !found && d.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to replace
@@ -298,7 +326,11 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// The file at this path hasn't changed. Just update cached metadata.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
+ if fs.opts.lisaEnabled {
+ d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: see above.
+ } else {
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
+ }
d.metadataMu.Unlock()
}
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index 8dcbc61ed..475322527 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/safemem"
@@ -112,10 +113,19 @@ func (d *dentry) prepareSaveRecursive(ctx context.Context) error {
return err
}
}
- if !d.readFile.isNil() || !d.writeFile.isNil() {
- d.fs.savedDentryRW[d] = savedDentryRW{
- read: !d.readFile.isNil(),
- write: !d.writeFile.isNil(),
+ if d.fs.opts.lisaEnabled {
+ if d.readFDLisa.Ok() || d.writeFDLisa.Ok() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: d.readFDLisa.Ok(),
+ write: d.writeFDLisa.Ok(),
+ }
+ }
+ } else {
+ if !d.readFile.isNil() || !d.writeFile.isNil() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: !d.readFile.isNil(),
+ write: !d.writeFile.isNil(),
+ }
}
}
d.dirMu.Lock()
@@ -177,25 +187,37 @@ func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRest
return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID)
}
fs.opts.fd = fd
- if err := fs.dial(ctx); err != nil {
- return err
- }
fs.inoByQIDPath = make(map[uint64]uint64)
+ fs.inoByKey = make(map[inoKey]uint64)
- // Restore the filesystem root.
- ctx.UninterruptibleSleepStart(false)
- attached, err := fs.client.Attach(fs.opts.aname)
- ctx.UninterruptibleSleepFinish(false)
- if err != nil {
- return err
- }
- attachFile := p9file{attached}
- qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
- if err != nil {
- return err
- }
- if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
- return err
+ if fs.opts.lisaEnabled {
+ rootInode, err := fs.initClientLisa(ctx)
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFileLisa(ctx, rootInode, &opts); err != nil {
+ return err
+ }
+ } else {
+ if err := fs.dial(ctx); err != nil {
+ return err
+ }
+
+ // Restore the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := fs.client.Attach(fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
+ return err
+ }
}
// Restore remaining dentries.
@@ -283,6 +305,55 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM
return nil
}
+func (d *dentry) restoreFileLisa(ctx context.Context, inode *lisafs.Inode, opts *vfs.CompleteRestoreOptions) error {
+ d.controlFDLisa = d.fs.clientLisa.NewFD(inode.ControlFD)
+
+ // Gofers do not preserve inoKey across checkpoint/restore, so:
+ //
+ // - We must assume that the remote filesystem did not change in a way that
+ // would invalidate dentries, since we can't revalidate dentries by
+ // checking inoKey.
+ //
+ // - We need to associate the new inoKey with the existing d.ino.
+ d.inoKey = inoKeyFromStat(&inode.Stat)
+ d.fs.inoMu.Lock()
+ d.fs.inoByKey[d.inoKey] = d.ino
+ d.fs.inoMu.Unlock()
+
+ // Check metadata stability before updating metadata.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.isRegularFile() {
+ if opts.ValidateFileSizes {
+ if inode.Stat.Mask&linux.STATX_SIZE != 0 {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))
+ }
+ if d.size != inode.Stat.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, inode.Stat.Size)
+ }
+ }
+ if opts.ValidateFileModificationTimestamps {
+ if inode.Stat.Mask&linux.STATX_MTIME == 0 {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))
+ }
+ if want := dentryTimestampFromLisa(inode.Stat.Mtime); d.mtime != want {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))
+ }
+ }
+ }
+ if !d.cachedMetadataAuthoritative() {
+ d.updateFromLisaStatLocked(&inode.Stat)
+ }
+
+ if rw, ok := d.fs.savedDentryRW[d]; ok {
+ if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// Preconditions: d is not synthetic.
func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
for _, child := range d.children {
@@ -305,19 +376,35 @@ func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.Comp
// only be detected by checking filesystem.syncableDentries). d.parent has been
// restored.
func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
- qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
- if err != nil {
- return err
- }
- if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
- return err
+ if d.fs.opts.lisaEnabled {
+ inode, err := d.parent.controlFDLisa.Walk(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFileLisa(ctx, inode, opts); err != nil {
+ return err
+ }
+ } else {
+ qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
+ return err
+ }
}
return d.restoreDescendantsRecursive(ctx, opts)
}
func (fd *specialFileFD) completeRestore(ctx context.Context) error {
d := fd.dentry()
- h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ var h handle
+ var err error
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ } else {
+ h, err = openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ }
if err != nil {
return err
}
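Aside: the restore-time validation in restoreFileLisa() only trusts a statx field when the server set its bit in Stat.Mask, which is why the checks above bail out when the bit is clear. A minimal illustration of that convention (assuming the linux.Statx type used above):

	// statxFieldKnown reports whether stat carries a trustworthy value for the
	// field selected by bit (e.g. linux.STATX_SIZE or linux.STATX_MTIME).
	func statxFieldKnown(stat *linux.Statx, bit uint32) bool {
		return stat.Mask&bit != 0
	}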
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index fe15f8583..86ab70453 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -59,11 +59,6 @@ func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
- cf, ok := sockTypeToP9(ce.Type())
- if !ok {
- return syserr.ErrConnectionRefused
- }
-
// No lock ordering required as only the ConnectingEndpoint has a mutex.
ce.Lock()
@@ -77,7 +72,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
return syserr.ErrInvalidEndpointState
}
- c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue())
+ c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue())
if err != nil {
ce.Unlock()
return err
@@ -95,7 +90,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
// UnidirectionalConnect implements
// transport.BoundEndpoint.UnidirectionalConnect.
func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
- c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{})
+ c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{})
if err != nil {
return nil, err
}
@@ -111,25 +106,39 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
return c, nil
}
-func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
- hostFile, err := e.dentry.file.connect(ctx, flags)
- if err != nil {
+func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
+ if e.dentry.fs.opts.lisaEnabled {
+ hostSockFD, err := e.dentry.controlFDLisa.Connect(ctx, sockType)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewSCMEndpoint(ctx, hostSockFD, queue, e.path)
+ if serr != nil {
+ unix.Close(hostSockFD)
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr)
+ return nil, serr
+ }
+ return c, nil
+ }
+
+ flags, ok := sockTypeToP9(sockType)
+ if !ok {
return nil, syserr.ErrConnectionRefused
}
- // Dup the fd so that the new endpoint can manage its lifetime.
- hostFD, err := unix.Dup(hostFile.FD())
+ hostFile, err := e.dentry.file.connect(ctx, flags)
if err != nil {
- log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err)
- return nil, syserr.FromError(err)
+ return nil, syserr.ErrConnectionRefused
}
- // After duplicating, we no longer need hostFile.
- hostFile.Close()
- c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
+ c, serr := host.NewSCMEndpoint(ctx, hostFile.FD(), queue, e.path)
if serr != nil {
- log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr)
+ hostFile.Close()
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr)
return nil, serr
}
+ // Ownership has been transferred to c.
+ hostFile.Release()
return c, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index a8d47b65b..c568bbfd2 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
@@ -149,6 +150,9 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error {
if !fd.vfsfd.IsWritable() {
return nil
}
+ if fs := fd.filesystem(); fs.opts.lisaEnabled {
+ return fd.handle.fdLisa.Flush(ctx)
+ }
return fd.handle.file.flush(ctx)
}
@@ -184,6 +188,9 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint
if fd.isRegularFile {
d := fd.dentry()
return d.doAllocate(ctx, offset, length, func() error {
+ if d.fs.opts.lisaEnabled {
+ return fd.handle.fdLisa.Allocate(ctx, mode, offset, length)
+ }
return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
})
}
@@ -371,10 +378,10 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- return fd.sync(ctx, false /* forFilesystemSync */)
+ return fd.sync(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */)
}
-func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error {
+func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error {
// Locks to ensure it didn't race with fd.Release().
fd.releaseMu.RLock()
defer fd.releaseMu.RUnlock()
@@ -391,6 +398,13 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error
ctx.UninterruptibleSleepFinish(false)
return err
}
+ if fs := fd.filesystem(); fs.opts.lisaEnabled {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, fd.handle.fdLisa.ID())
+ return nil
+ }
+ return fd.handle.fdLisa.Sync(ctx)
+ }
return fd.handle.file.fsync(ctx)
}()
if err != nil {
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
index dbd834c67..27d9be5c4 100644
--- a/pkg/sentry/fsimpl/gofer/symlink.go
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -35,7 +35,13 @@ func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
return target, nil
}
}
- target, err := d.file.readlink(ctx)
+ var target string
+ var err error
+ if d.fs.opts.lisaEnabled {
+ target, err = d.controlFDLisa.ReadLinkAt(ctx)
+ } else {
+ target, err = d.file.readlink(ctx)
+ }
if d.fs.opts.interop != InteropModeShared {
if err == nil {
d.haveTarget = true
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 9cbe805b9..07940b225 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -17,6 +17,7 @@ package gofer
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -24,6 +25,10 @@ func dentryTimestampFromP9(s, ns uint64) int64 {
return int64(s*1e9 + ns)
}
+func dentryTimestampFromLisa(t linux.StatxTimestamp) int64 {
+ return t.Sec*1e9 + int64(t.Nsec)
+}
+
// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD
index 943fa180d..35feb969f 100644
--- a/pkg/sentry/seccheck/BUILD
+++ b/pkg/sentry/seccheck/BUILD
@@ -8,6 +8,8 @@ go_fieldenum(
name = "seccheck_fieldenum",
srcs = [
"clone.go",
+ "execve.go",
+ "exit.go",
"task.go",
],
out = "seccheck_fieldenum.go",
@@ -29,6 +31,8 @@ go_library(
name = "seccheck",
srcs = [
"clone.go",
+ "execve.go",
+ "exit.go",
"seccheck.go",
"seccheck_fieldenum.go",
"seqatomic_checkerslice_unsafe.go",
diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go
new file mode 100644
index 000000000..f36e0730e
--- /dev/null
+++ b/pkg/sentry/seccheck/execve.go
@@ -0,0 +1,65 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// ExecveInfo contains information used by the Execve checkpoint.
+//
+// +fieldenum Execve
+type ExecveInfo struct {
+ // Invoker identifies the invoking thread.
+ Invoker TaskInfo
+
+ // Credentials are the invoking thread's credentials.
+ Credentials *auth.Credentials
+
+ // BinaryPath is a path to the executable binary file being switched to in
+ // the mount namespace in which it was opened.
+ BinaryPath string
+
+ // Argv is the new process image's argument vector.
+ Argv []string
+
+ // Env is the new process image's environment variables.
+ Env []string
+
+ // BinaryMode is the executable binary file's mode.
+ BinaryMode uint16
+
+ // BinarySHA256 is the SHA-256 hash of the executable binary file.
+ //
+ // Note that this requires reading the entire file into memory, which is
+ // likely to be extremely slow.
+ BinarySHA256 [32]byte
+}
+
+// ExecveReq returns fields required by the Execve checkpoint.
+func (s *state) ExecveReq() ExecveFieldSet {
+ return s.execveReq.Load()
+}
+
+// Execve is called at the Execve checkpoint.
+func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.Execve(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go
new file mode 100644
index 000000000..69cb6911c
--- /dev/null
+++ b/pkg/sentry/seccheck/exit.go
@@ -0,0 +1,57 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// ExitNotifyParentInfo contains information used by the ExitNotifyParent
+// checkpoint.
+//
+// +fieldenum ExitNotifyParent
+type ExitNotifyParentInfo struct {
+ // Exiter identifies the exiting thread. Note that by the checkpoint's
+ // definition, Exiter.ThreadID == Exiter.ThreadGroupID and
+ // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting
+ // ThreadGroup* fields is redundant.
+ Exiter TaskInfo
+
+ // ExitStatus is the exiting thread group's exit status, as reported
+ // by wait*().
+ ExitStatus linux.WaitStatus
+}
+
+// ExitNotifyParentReq returns fields required by the ExitNotifyParent
+// checkpoint.
+func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet {
+ return s.exitNotifyParentReq.Load()
+}
+
+// ExitNotifyParent is called at the ExitNotifyParent checkpoint.
+//
+// The ExitNotifyParent checkpoint occurs when a zombied thread group leader,
+// not waiting for exit acknowledgement from a non-parent ptracer, becomes the
+// last non-dead thread in its thread group and notifies its parent of its
+// exiting.
+func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.ExitNotifyParent(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go
index b6c9d44ce..e13274096 100644
--- a/pkg/sentry/seccheck/seccheck.go
+++ b/pkg/sentry/seccheck/seccheck.go
@@ -29,6 +29,8 @@ type Point uint
// PointX represents the checkpoint X.
const (
PointClone Point = iota
+ PointExecve
+ PointExitNotifyParent
// Add new Points above this line.
pointLength
@@ -47,6 +49,8 @@ const (
// registered concurrently with invocations of checkpoints).
type Checker interface {
Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+ Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error
+ ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error
}
// CheckerDefaults may be embedded by implementations of Checker to obtain
@@ -58,6 +62,16 @@ func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info Clone
return nil
}
+// Execve implements Checker.Execve.
+func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error {
+ return nil
+}
+
+// ExitNotifyParent implements Checker.ExitNotifyParent.
+func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error {
+ return nil
+}
+
// CheckerReq indicates what checkpoints a corresponding Checker runs at, and
// what information it requires at those checkpoints.
type CheckerReq struct {
@@ -69,7 +83,9 @@ type CheckerReq struct {
// All of the following fields indicate what fields in the corresponding
// XInfo struct will be requested at the corresponding checkpoint.
- Clone CloneFields
+ Clone CloneFields
+ Execve ExecveFields
+ ExitNotifyParent ExitNotifyParentFields
}
// Global is the method receiver of all seccheck functions.
@@ -101,7 +117,9 @@ type state struct {
// corresponding XInfo struct have been requested by any registered
// checker, are accessed using atomic memory operations, and are mutated
// with registrationMu locked.
- cloneReq CloneFieldSet
+ cloneReq CloneFieldSet
+ execveReq ExecveFieldSet
+ exitNotifyParentReq ExitNotifyParentFieldSet
}
// AppendChecker registers the given Checker to execute at checkpoints. The
@@ -110,7 +128,11 @@ type state struct {
func (s *state) AppendChecker(c Checker, req *CheckerReq) {
s.registrationMu.Lock()
defer s.registrationMu.Unlock()
+
s.cloneReq.AddFieldsLoadable(req.Clone)
+ s.execveReq.AddFieldsLoadable(req.Execve)
+ s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent)
+
s.appendCheckerLocked(c)
for _, p := range req.Points {
word, bit := p/32, p%32
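Aside: tying the new checkpoints together, a hedged sketch of what a checker built on them might look like (not part of this change; per-field requests via req.Execve are elided):

	import (
		"gvisor.dev/gvisor/pkg/context"
		"gvisor.dev/gvisor/pkg/log"
		"gvisor.dev/gvisor/pkg/sentry/seccheck"
	)

	type execLogger struct {
		seccheck.CheckerDefaults // no-op implementations for the other checkpoints
	}

	// Execve implements seccheck.Checker.Execve.
	func (execLogger) Execve(ctx context.Context, _ seccheck.ExecveFieldSet, info seccheck.ExecveInfo) error {
		log.Infof("execve: %q argv=%v", info.BinaryPath, info.Argv)
		return nil
	}

	func registerExecLogger() {
		seccheck.Global.AppendChecker(execLogger{}, &seccheck.CheckerReq{
			Points: []seccheck.Point{seccheck.PointExecve},
		})
	}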
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00a5e729a..f9a5b0df1 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -355,6 +355,17 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketIn
)
}
+// PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message.
+func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_PKTINFO,
+ t.Arch().Width(),
+ packetInfo,
+ )
+}
+
// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
var level uint32
@@ -412,6 +423,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
}
+ if cmsgs.IP.HasIPv6PacketInfo {
+ buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf)
+ }
+
if cmsgs.IP.OriginalDstAddress != nil {
buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
@@ -453,6 +468,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
}
+ if cmsgs.IP.HasIPv6PacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo)
+ }
+
if cmsgs.IP.OriginalDstAddress != nil {
space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index aa081e90d..dedc32dda 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1371,6 +1371,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
return &v, nil
+ case linux.IPV6_RECVPKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo()))
+ return &v, nil
+
case linux.IP6T_ORIGINAL_DST:
if outLen < sockAddrInet6Size {
return nil, syserr.ErrInvalidArgument
@@ -2127,6 +2135,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
return nil
+ case linux.IPV6_RECVPKTINFO:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2516,7 +2533,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
linux.IPV6_RECVPATHMTU,
- linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
@@ -2742,6 +2758,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr
TClass: readCM.TClass,
HasIPPacketInfo: readCM.HasIPPacketInfo,
PacketInfo: readCM.PacketInfo,
+ HasIPv6PacketInfo: readCM.HasIPv6PacketInfo,
+ IPv6PacketInfo: readCM.IPv6PacketInfo,
OriginalDstAddress: readCM.OriginalDstAddress,
SockErr: readCM.SockErr,
},
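Aside: for context, this is roughly how an application exercises the new option from user space (a sketch using golang.org/x/sys/unix on an already-bound IPv6 UDP socket; error handling abbreviated):

	// recvWithPktInfo enables IPV6_RECVPKTINFO and returns the control
	// messages attached to the next datagram, including IPV6_PKTINFO.
	func recvWithPktInfo(fd int) ([]unix.SocketControlMessage, error) {
		if err := unix.SetsockoptInt(fd, unix.SOL_IPV6, unix.IPV6_RECVPKTINFO, 1); err != nil {
			return nil, err
		}
		buf := make([]byte, 1500)
		oob := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
		_, oobn, _, _, err := unix.Recvmsg(fd, buf, oob, 0)
		if err != nil {
			return nil, err
		}
		// oob[:oobn] carries the message packed by PackIPv6PacketInfo above.
		return unix.ParseSocketControlMessage(oob[:oobn])
	}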
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 841d5bd55..2f0eb4a6c 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -56,6 +56,17 @@ func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPack
return p
}
+// ipv6PacketInfoToLinux converts IPv6PacketInfo from tcpip format to Linux
+// format.
+func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo {
+ var p linux.ControlMessageIPv6PacketInfo
+ if n := copy(p.Addr[:], []byte(packetInfo.Addr)); n != len(p.Addr) {
+ panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr)))
+ }
+ p.NIC = uint32(packetInfo.NIC)
+ return p
+}
+
// errOriginToLinux maps tcpip socket origin to Linux socket origin constants.
func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 {
switch origin {
@@ -114,7 +125,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
if cmgs.HasOriginalDstAddress {
orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
}
- return IPControlMessages{
+ cm := IPControlMessages{
HasTimestamp: cmgs.HasTimestamp,
Timestamp: cmgs.Timestamp,
HasInq: cmgs.HasInq,
@@ -125,9 +136,16 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
TClass: cmgs.TClass,
HasIPPacketInfo: cmgs.HasIPPacketInfo,
PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ HasIPv6PacketInfo: cmgs.HasIPv6PacketInfo,
OriginalDstAddress: orgDstAddr,
SockErr: sockErrCmsgToLinux(cmgs.SockErr),
}
+
+ if cm.HasIPv6PacketInfo {
+ cm.IPv6PacketInfo = ipv6PacketInfoToLinux(cmgs.IPv6PacketInfo)
+ }
+
+ return cm
}
// IPControlMessages contains socket control messages for IP sockets.
@@ -166,6 +184,12 @@ type IPControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo linux.ControlMessageIPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo linux.ControlMessageIPv6PacketInfo
+
// OriginalDestinationAddress holds the original destination address
// and port of the incoming packet.
OriginalDstAddress linux.SockAddr
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index a9cedcf5f..188ad3bd9 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -59,12 +59,14 @@ func (q *queue) Close() {
// q.WriterQueue.Notify(waiter.WritableEvents)
func (q *queue) Reset(ctx context.Context) {
q.mu.Lock()
- for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
- cur.Release(ctx)
- }
+ dataList := q.dataList
q.dataList.Reset()
q.used = 0
q.mu.Unlock()
+
+ for cur := dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release(ctx)
+ }
}
// DecRef implements RefCounter.DecRef.
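Aside: the queue.Reset() change above detaches the message list while holding q.mu and only calls Release() on the entries after unlocking, presumably so that entry cleanup (which may be slow or take other locks) never runs under the queue mutex. The same shape in a self-contained form (a generic sketch, not gVisor code):

	type resource interface{ Release() }

	type pool struct {
		mu    sync.Mutex
		items []resource
	}

	// Reset drops all items. Release may be slow or acquire other locks, so it
	// is invoked only after p.mu has been dropped.
	func (p *pool) Reset() {
		p.mu.Lock()
		items := p.items
		p.items = nil
		p.mu.Unlock()

		for _, it := range items {
			it.Release()
		}
	}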
diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go
index 3560e66ae..9b8c9a480 100644
--- a/pkg/sentry/time/sampler_arm64.go
+++ b/pkg/sentry/time/sampler_arm64.go
@@ -30,9 +30,9 @@ func getDefaultArchOverheadCycles() TSCValue {
// frqRatio. defaultOverheadCycles of ARM equals to that on
// x86 divided by frqRatio
cntfrq := getCNTFRQ()
- frqRatio := 1000000000 / cntfrq
+ frqRatio := 1000000000 / float64(cntfrq)
overheadCycles := (1 * 1000) / frqRatio
- return overheadCycles
+ return TSCValue(overheadCycles)
}
// defaultOverheadTSC is the default estimated syscall overhead in TSC cycles.
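Aside: a worked example of why the float conversion above matters (the 24 MHz counter frequency is hypothetical): frqRatio = 1e9 / 24e6 ≈ 41.67, so overheadCycles ≈ 1000 / 41.67 ≈ 24 TSC-equivalent cycles. With the previous all-integer arithmetic frqRatio was truncated (41 here), and for any counter faster than 1 GHz it would truncate to 0, turning the next line into an integer divide-by-zero.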
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 7fd7f000d..40aff2927 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -223,6 +223,12 @@ func (rp *ResolvingPath) Final() bool {
return rp.curPart == 0 && !rp.pit.NextOk()
}
+// Pit returns a copy of rp's current path iterator. Modifying the iterator
+// does not change rp.
+func (rp *ResolvingPath) Pit() fspath.Iterator {
+ return rp.pit
+}
+
// Component returns the current path component in the stream represented by
// rp.
//
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f34bf8dd..24c2c3e6b 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
+// in ControlMessages.
+func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPv6PacketInfo {
+ t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo)
+ } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
+ t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
// field in ControlMessages.
func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 2aa38eb98..d51c36f19 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -240,12 +240,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()
- sent := e.stats.icmp.packetsSent
- if !e.protocol.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return
- }
-
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
// be modifying fields.
@@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
defer r.Release()
+ sent := e.stats.icmp.packetsSent
+ if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
+ sent.rateLimited.Increment()
+ return
+ }
+
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -562,13 +562,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
@@ -606,6 +599,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonProtoUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetworkUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonFragmentationNeeded:
+ return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
+ case *icmpReasonTTLExceeded:
+ return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
+ case *icmpReasonParamProblem:
+ return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType, icmpCode) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
@@ -658,44 +680,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonProtoUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetworkUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4NetUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4HostUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonFragmentationNeeded:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
- counter = sent.dstUnreachable
- case *icmpReasonTTLExceeded:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.timeExceeded
- case *icmpReasonParamProblem:
- icmpHdr.SetType(header.ICMPv4ParamProblem)
- icmpHdr.SetCode(header.ICMPv4UnusedCode)
- icmpHdr.SetPointer(reason.pointer)
- counter = sent.paramProblem
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetPointer(pointer)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index aef789b4c..25f5a52e3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv4(newPkt.NetworkHeader().View())
// As per RFC 791 page 30, Time to Live,
//
@@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// Even if no local information is available on the time actually
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
+ // We perform a full checksum as we may have updated options above. The IP
+ // header is relatively small so this is not expected to be an expensive
+ // operation.
+ newHdr.SetChecksum(0)
+ newHdr.SetChecksum(^newHdr.CalculateChecksum())
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -1200,6 +1215,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv4Type]struct{}
}
// defaultTTL is the current default TTL for the protocol. Only the
@@ -1315,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// allowICMPReply reports whether an ICMP reply with the provided type and code
+// may be sent, per the rate mask options and the global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool {
+ // Mimic Linux and never rate limit for PMTU discovery.
+ // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288
+ if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded {
+ return true
+ }
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1394,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
+ // Set ICMP rate limiting to Linux defaults.
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html.
+ p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{
+ header.ICMPv4DstUnreachable: struct{}{},
+ header.ICMPv4SrcQuench: struct{}{},
+ header.ICMPv4TimeExceeded: struct{}{},
+ header.ICMPv4ParamProblem: struct{}{},
+ }
return p
}
}
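The gating added above composes a rate mask (which ICMPv4 types are subject to limiting) with the stack's global token bucket. A self-contained sketch of that shape, using golang.org/x/time/rate directly rather than the stack's own limiter; the limit and burst values are made up for illustration and are not gVisor's defaults:

package main

import (
	"fmt"

	"golang.org/x/time/rate"
)

func main() {
	// Illustration only; mirrors the shape of allowICMPReply above, not the
	// actual stack.ICMPRateLimiter plumbing.
	const (
		icmpEchoReply      = 0
		icmpDstUnreachable = 3
		icmpTimeExceeded   = 11
	)
	limited := map[int]struct{}{icmpDstUnreachable: {}, icmpTimeExceeded: {}}
	limiter := rate.NewLimiter(10, 5) // 10 tokens/s, burst of 5 (illustrative numbers)

	allowReply := func(icmpType int) bool {
		if _, ok := limited[icmpType]; ok {
			return limiter.Allow()
		}
		// Types outside the mask (e.g. echo replies) are never limited.
		return true
	}

	for i := 0; i < 7; i++ {
		fmt.Println(allowReply(icmpEchoReply), allowReply(icmpDstUnreachable))
	}
	// Echo replies stay true every round; destination unreachable flips to
	// false once the burst of 5 is spent (assuming no refill in between).
}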
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index e7b5b3ea2..ef91245d7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -3373,3 +3373,139 @@ func TestCloseLocking(t *testing.T) {
}
}()
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv4ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.UDPProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f99cbf8f3..f814926a3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -51,6 +51,7 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 94caaae6c..6c6107264 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -692,6 +692,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
defer r.Release()
+ if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) {
+ sent.rateLimited.Increment()
+ return
+ }
+
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data().ExtractVV(),
@@ -1174,13 +1179,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -1198,6 +1196,33 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) {
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonPacketTooBig:
+ return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0
+ case *icmpReasonHopLimitExceeded:
+ return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
// As per RFC 4443 section 2.4
@@ -1232,40 +1257,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonParameterProblem:
- icmpHdr.SetType(header.ICMPv6ParamProblem)
- icmpHdr.SetCode(reason.code)
- icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.paramProblem
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonPacketTooBig:
- icmpHdr.SetType(header.ICMPv6PacketTooBig)
- icmpHdr.SetCode(header.ICMPv6UnusedCode)
- counter = sent.packetTooBig
- case *icmpReasonHopLimitExceeded:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.timeExceeded
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetTypeSpecific(typeSpecific)
+
dataRange := newPkt.Data().AsRange()
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 3b4c235fa..03d9f425c 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -1435,6 +1436,8 @@ func TestPacketQueing(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
Clock: clock,
})
+ // Make sure ICMP rate limiting doesn't get in our way.
+ s.SetICMPLimit(rate.Inf)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c824e27fa..dab99d00d 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv6(newPkt.NetworkHeader().View())
// As per RFC 8200 section 3,
//
@@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -1987,6 +1990,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv6Type]struct{}
}
ids []uint32
@@ -1998,7 +2004,8 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- fragmentation *fragmentation.Fragmentation
+ fragmentation *fragmentation.Fragmentation
+ icmpRateLimiter *stack.ICMPRateLimiter
}
// Number returns the ipv6 protocol number.
@@ -2082,6 +2089,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -2167,6 +2181,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// allowICMPReply reports whether an ICMP reply with the provided type may be
+// sent, per the rate mask options and the global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
@@ -2263,6 +2289,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
+ // Set default ICMP rate limiting to Linux defaults.
+ //
+ // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big)
+ // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt.
+ defaultIcmpTypes := make(map[header.ICMPv6Type]struct{})
+ for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ {
+ switch i {
+ case header.ICMPv6PacketTooBig:
+ // Do not rate limit packet too big by default.
+ default:
+ defaultIcmpTypes[i] = struct{}{}
+ }
+ }
+ p.mu.icmpRateLimitedTypes = defaultIcmpTypes
+
return p
}
}
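Spelled out, the default mask built above limits the ICMPv6 error types (Destination Unreachable, Time Exceeded, Parameter Problem, and the rest of 0-127) but not Packet Too Big (2) or any informational type such as Echo Reply (129); that is why the "echo" case in the rate-limit test below sees a reply on every round. A small standalone check of that set, using raw RFC 4443 type numbers rather than gVisor constants:

package main

import "fmt"

func main() {
	const (
		dstUnreachable = 1
		packetTooBig   = 2
		timeExceeded   = 3
		echoRequest    = 128
		echoReply      = 129
	)
	// Same construction as the ipv6 protocol default above.
	limited := make(map[int]struct{})
	for t := 0; t < echoRequest; t++ {
		if t == packetTooBig {
			continue
		}
		limited[t] = struct{}{}
	}
	check := func(t int) bool { _, ok := limited[t]; return ok }
	fmt.Println(check(dstUnreachable), check(timeExceeded)) // true true
	fmt.Println(check(packetTooBig), check(echoReply))      // false false
}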
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 0735ebb23..e5286081e 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -3373,7 +3373,8 @@ func TestForwarding(t *testing.T) {
ipHeaderLength := header.IPv6MinimumSize
icmpHeaderLength := header.ICMPv6MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen
+ totalLength := ipHeaderLength + payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
@@ -3391,7 +3392,7 @@ func TestForwarding(t *testing.T) {
copy(hdr.Prepend(extHdrLen), extHdrBytes)
ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ PayloadLength: uint16(payloadLength),
TransportProtocol: transportProtocol,
HopLimit: test.TTL,
SrcAddr: test.sourceAddr,
@@ -3521,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
t.Error(err)
}
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv6ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+
+ // Calculate the UDP checksum and set it.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(nil, sum)
+ udpH.SetChecksum(^udpH.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 34ac62444..b0b2d0afd 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -170,10 +170,14 @@ type SocketOptions struct {
// message is passed with incoming packets.
receiveTClassEnabled uint32
- // receivePacketInfoEnabled is used to specify if more inforamtion is
- // provided with incoming packets such as interface index and address.
+ // receivePacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv4 packets.
receivePacketInfoEnabled uint32
+ // receiveIPv6PacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv6 packets.
+ receiveIPv6PacketInfoEnabled uint32
+
// hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
// being written have an IP header and the endpoint should not attach an IP
// header.
@@ -360,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) {
storeAtomicBool(&so.receivePacketInfoEnabled, v)
}
+// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0
+}
+
+// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v)
+}
+
// GetHeaderIncluded gets value for IP_HDRINCL option.
func (so *SocketOptions) GetHeaderIncluded() bool {
return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..4fb7e9adb 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
@@ -209,27 +215,38 @@ type bucket struct {
tuples tupleList
}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
+ }
+
+ return nil, false
+}
+
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
//
// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
return tupleID{
srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
+ srcPort: transportHeader.SourcePort(),
dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return false
}
@@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
}
netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
- }
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
@@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
- }
+ if conn.manip == manipDestination && dir == dirOriginal {
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
+ pkt.NatDone = true
+ } else if conn.manip == manipSource && dir == dirReply {
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- }
+ if conn.manip == manipSource && dir == dirOriginal {
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+ updateSRCFields = true
+ pkt.NatDone = true
+ } else if conn.manip == manipDestination && dir == dirReply {
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+ updateSRCFields = true
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
+
if !pkt.NatDone {
return false
}
@@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
fullChecksum := false
updatePseudoHeader := false
switch hook {
- case Prerouting, Input:
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+ case Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
fullChecksum = true
@@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
rewritePacket(
netHeader,
- tcpHeader,
+ transportHeader,
updateSRCFields,
fullChecksum,
updatePseudoHeader,
@@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
return false
}
@@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber, header.UDPProtocolNumber:
+ default:
+ // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable
+ // connections.
return
}
@@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ conn.updateLocked(pkt, hook)
ct.insertConn(conn)
}
@@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
conn, _ := ct.connForTID(tid)
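The rewrite cases above reduce to a small matrix: the hook decides whether the source or the destination fields are rewritten, and (manipulation, direction) decides which tuple supplies the replacement. A self-contained restatement of that matrix; the manip*/dir* names shadow the unexported ones in this file and the tuple type is simplified:

package main

import "fmt"

type (
	manip     int
	direction int
	hookKind  int
)

const (
	manipSource manip = iota
	manipDestination
)

const (
	dirOriginal direction = iota
	dirReply
)

const (
	preroutingOrOutput hookKind = iota // destination may be rewritten
	inputOrPostrouting                 // source may be rewritten
)

type tuple struct {
	srcAddr, dstAddr string
	srcPort, dstPort uint16
}

// natRewrite mirrors handlePacket above: it reports which address/port to
// write back into the packet, whether the source (rather than destination)
// fields change, and whether any rewrite applies at all.
func natRewrite(h hookKind, m manip, d direction, original, reply tuple) (addr string, port uint16, updateSrc, ok bool) {
	switch h {
	case preroutingOrOutput:
		if m == manipDestination && d == dirOriginal {
			return reply.srcAddr, reply.srcPort, false, true // DNAT the request
		}
		if m == manipSource && d == dirReply {
			return original.srcAddr, original.srcPort, false, true // un-SNAT the reply
		}
	case inputOrPostrouting:
		if m == manipSource && d == dirOriginal {
			return reply.dstAddr, reply.dstPort, true, true // SNAT the request
		}
		if m == manipDestination && d == dirReply {
			return original.dstAddr, original.dstPort, true, true // un-DNAT the reply
		}
	}
	return "", 0, false, false
}

func main() {
	original := tuple{srcAddr: "10.0.0.2", srcPort: 4242, dstAddr: "203.0.113.5", dstPort: 80}
	reply := tuple{srcAddr: "203.0.113.5", srcPort: 80, dstAddr: "192.0.2.1", dstPort: 4242}
	// An SNATed connection leaving the router: rewrite the source to the
	// NAT address recorded in the reply tuple.
	fmt.Println(natRewrite(inputOrPostrouting, manipSource, dirOriginal, original, reply))
}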
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..3617b6dd0 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -482,11 +482,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..de5997e9e 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -206,34 +206,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
+ port := st.Port
+
+ if port == 0 {
+ switch pkt.TransportProtocolNumber {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ }
+ }
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
+ // Set up connection for matching NAT rule. Only the first packet of the
+ // connection comes here. Other packets will be manipulated in connection
+ // tracking.
+ //
+ // Does nothing if the protocol does not support connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, r)
}
return RuleAccept, 0
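With the port defaulting above, an SNAT rule no longer has to name a source port: leaving SNATTarget.Port as zero keeps the packet's original source port for both UDP and TCP. A sketch of such a rule, assuming a stack s and a NAT address natAddr; it mirrors the setup in TestSNAT further down:

// Sketch only: s and natAddr are assumed; compare TestSNAT below.
ipt := s.IPTables()
nat := ipt.GetTable(stack.NATID, false /* ipv6 */)
ruleIdx := nat.BuiltinChains[stack.Postrouting]
nat.Rules[ruleIdx].Target = &stack.SNATTarget{
	NetworkProtocol: ipv4.ProtocolNumber,
	Addr:            natAddr,
	// Port left as zero: the packet's original source port is kept.
}
// Optionally restrict the rule to an output interface, as TestSNAT does:
// nat.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
// Make sure the packet is not dropped by the next rule.
nat.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.NATID, nat, false /* ipv6 */); err != nil {
	// handle the error
}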
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 29c22bfd4..bf248ef20 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -335,9 +335,45 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
// tell if a noop connection should be inserted at Input hook. Once conntrack
// redefines the manipulation field as mutable, we won't need the special noop
// connection.
- if pk.NatDone {
- newPk.NatDone = true
+ newPk.NatDone = pk.NatDone
+ return newPk
+}
+
+// DeepCopyForForwarding creates a deep copy of the packet buffer for
+// forwarding.
+//
+// The returned packet buffer will have the network and transport headers
+// set if the original packet buffer did.
+func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
+ newPk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: reservedHeaderBytes,
+ Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
+ IsForwardedPacket: true,
+ })
+
+ {
+ consumeBytes := pk.NetworkHeader().View().Size()
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
+ }
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
+
+ {
+ consumeBytes := pk.TransportHeader().View().Size()
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
+ }
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
+ }
+
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that this flag no
+ // longer needs to be maintained in the packet.
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ newPk.NatDone = pk.NatDone
+
return newPk
}
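The forwarding paths in ipv4.go and ipv6.go earlier in this diff are the callers of the new helper; the idea is that the forwarder mutates and transmits its own deep copy while the original PacketBuffer stays untouched for other consumers. Roughly, reusing the variables of ipv4's forwardPacket shown above (not new API):

// From the IPv4 forwarding hunk earlier in this diff: r, pkt and ttl are
// forwardPacket's locals.
newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
newHdr := header.IPv4(newPkt.NetworkHeader().View())
newHdr.SetTTL(ttl - 1)
newHdr.SetChecksum(0)
newHdr.SetChecksum(^newHdr.CalculateChecksum())
// The copy already has its network/transport headers consumed, so it can be
// written via the outgoing NIC's endpoint instead of rebuilding a
// header-included packet from a flat view.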
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index a9ce148b9..c5e896295 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -451,6 +451,12 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo IPv6PacketInfo
+
// HasOriginalDestinationAddress indicates whether OriginalDstAddress is
// set.
HasOriginalDstAddress bool
@@ -1164,6 +1170,14 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// IPv6PacketInfo is the message structure for IPV6_PKTINFO.
+//
+// +stateify savable
+type IPv6PacketInfo struct {
+ Addr Address
+ NIC NICID
+}
+
// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max send buffer sizes.
type SendBufferSizeOption struct {
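ControlMessages now carries separate IPv4 and IPv6 packet-info payloads. A minimal sketch of how an endpoint could populate the new IPv6 fields when IPV6_RECVPKTINFO is enabled; the helper and its inputs are hypothetical, the real plumbing lives in the UDP and netstack files elsewhere in this change:

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// ipv6ControlMessagesFor is a hypothetical helper, not part of this change:
// it shows the intended use of HasIPv6PacketInfo/IPv6PacketInfo together
// with the new SocketOptions getter.
func ipv6ControlMessagesFor(so *tcpip.SocketOptions, nic tcpip.NICID, dst tcpip.Address) tcpip.ControlMessages {
	var cm tcpip.ControlMessages
	if so.GetIPv6ReceivePacketInfo() {
		cm.HasIPv6PacketInfo = true
		cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
			NIC:  nic,
			Addr: dst,
		}
	}
	return cm
}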
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..7c998eaae 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 28b49c6be..bdf4a64b9 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -1156,3 +1161,286 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestSNAT(t *testing.T) {
+ const listenPort = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+
+ nattedClientAddr tcpip.Address
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ipt := routerStack.IPTables()
+ filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err)
+ }
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ {
+ name: "IPv6 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ipt := routerStack.IPTables()
+ filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err)
+ }
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
+ }
+ nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ nattedClientAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, nattedClientAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 947bcc7b1..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests.
+const (
+ Host1NICName = "host1NIC"
+ RouterNIC1Name = "routerNIC1"
+ RouterNIC2Name = "routerNIC2"
+ Host2NICName = "host2NIC"
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 1f30e5adb..e4a64e191 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -82,11 +82,9 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
// +checklocks:mu
- netProto tcpip.NetworkProtocolNumber
- // +checklocks:mu
closed bool
// +checklocks:mu
- bound bool
+ boundNetProto tcpip.NetworkProtocolNumber
// +checklocks:mu
boundNIC tcpip.NICID
@@ -98,10 +96,10 @@ type endpoint struct {
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: s,
+ cooked: cooked,
+ boundNetProto: netProto,
+ waiterQueue: waiterQueue,
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
@@ -137,7 +135,7 @@ func (ep *endpoint) Close() {
return
}
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
ep.rcvMu.Lock()
defer ep.rcvMu.Unlock()
@@ -150,7 +148,6 @@ func (ep *endpoint) Close() {
}
ep.closed = true
- ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -211,7 +208,7 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
- proto := ep.netProto
+ proto := ep.boundNetProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
@@ -294,30 +291,41 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
defer ep.mu.Unlock()
netProto := tcpip.NetworkProtocolNumber(addr.Port)
- if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto {
- // If the NIC being bound is the same then just return success.
+ if netProto == 0 {
+ // Do not allow unbinding the network protocol.
+ netProto = ep.boundNetProto
+ }
+
+ if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
+ // Already bound to the requested NIC and network protocol.
return nil
}
- // Unregister endpoint with all the nics.
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
- ep.bound = false
+ // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
+ // binding.
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
+ ep.boundNIC = 0
+ ep.boundNetProto = 0
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
- ep.bound = true
ep.boundNIC = addr.NIC
- ep.netProto = netProto
-
+ ep.boundNetProto = netProto
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: ep.boundNIC,
+ Port: uint16(ep.boundNetProto),
+ }, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -473,7 +481,7 @@ func (*endpoint) State() uint32 {
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
defer ep.mu.RUnlock()
- return &stack.TransportEndpointInfo{NetProto: ep.netProto}
+ return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
}
// Stats returns a pointer to the endpoint stats.
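For packet endpoints, FullAddress.Port carries the network protocol number (the Bind code above converts it directly), so the reworked bind/local-address behaviour looks like this from a caller's point of view. A sketch, not a test from this change; s, nicID and wq are assumed to exist:

// Sketch only. packet.NewEndpoint is the constructor shown above.
ep, err := packet.NewEndpoint(s, true /* cooked */, header.IPv4ProtocolNumber, &wq)
if err != nil {
	// handle the error
}
defer ep.Close()

// Rebind to a specific NIC. Port == 0 keeps the current protocol binding;
// a non-zero Port switches it (and re-registers the endpoint).
if err := ep.Bind(tcpip.FullAddress{NIC: nicID}); err != nil {
	// handle the error
}

// GetLocalAddress no longer returns ErrNotSupported: it reports the bound
// NIC and, in Port, the bound network protocol.
if addr, err := ep.GetLocalAddress(); err == nil {
	_ = addr // addr.NIC == nicID, addr.Port == uint16(header.IPv4ProtocolNumber)
}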
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index d2768db7b..88cd80ad3 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -57,9 +58,8 @@ func (ep *endpoint) afterLoad() {
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
- if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(err)
+ if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
+ panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
ep.rcvMu.Lock()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 5148fe157..20958d882 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -80,9 +80,10 @@ go_library(
go_test(
name = "tcp_x_test",
- size = "medium",
+ size = "large",
srcs = [
"dual_stack_test.go",
+ "rcv_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
"tcp_rack_test.go",
@@ -114,16 +115,6 @@ go_test(
)
go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
name = "tcp_test",
size = "small",
srcs = [
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 03c9fafa1..ff0a5df9c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -401,43 +401,6 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e.h = nil
}
-// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// listener has transitioned out of the listen state (accepted is the zero
-// value), the new endpoint is reset instead.
-func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- return false
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- return true
- }
- }()
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- n.notifyProtocolGoroutine(notifyReset)
- }
-}
-
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
@@ -521,7 +484,40 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.
ctx.cleanupCompletedHandshake(h)
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep, false /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.mu.Lock()
+ e.pendingAccepted.Add(1)
+ e.mu.Unlock()
+ defer e.pendingAccepted.Done()
+
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ if e.accepted == (accepted{}) {
+ // If the listener has transitioned out of the listen state (accepted
+ // is the zero value), the new endpoint is reset instead.
+ return false
+ }
+ if e.accepted.acceptQueueIsFullLocked() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.accepted.endpoints.PushBack(h.ep)
+ atomic.AddInt32(&e.synRcvdCount, -1)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
}()
return nil
@@ -544,11 +540,15 @@ func (e *endpoint) synRcvdBacklogFull() bool {
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
+ full := e.accepted.acceptQueueIsFullLocked()
e.acceptMu.Unlock()
return full
}
+func (a *accepted) acceptQueueIsFullLocked() bool {
+ return a.endpoints.Len() == a.cap
+}
+
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
@@ -627,12 +627,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- if e.acceptQueueIsFull() {
+ // Keep hold of acceptMu until the new endpoint is in the accept queue (or
+ // if there is an error), to guarantee that we will keep our spot in the
+ // queue even if another handshake from the syn queue completes.
+ e.acceptMu.Lock()
+ if e.accepted.acceptQueueIsFullLocked() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
// retransmitted by the sender anyway and we can
// complete the connection at the time of retransmit if
// the backlog has space.
+ e.acceptMu.Unlock()
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -654,6 +659,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
+ e.acceptMu.Unlock()
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -695,6 +701,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
+ e.acceptMu.Unlock()
return err
}
@@ -706,6 +713,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !n.reserveTupleLocked() {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -723,6 +731,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.boundBindToDevice,
); err != nil {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -755,20 +764,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.newSegmentWaker.Assert()
}
- // Do the delivery in a separate goroutine so
- // that we don't block the listen loop in case
- // the application is slow to accept or stops
- // accepting.
- //
- // NOTE: This won't result in an unbounded
- // number of goroutines as we do check before
- // entering here that there was at least some
- // space available in the backlog.
-
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n, true /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.accepted.endpoints.PushBack(n)
+ e.acceptMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.ReadableEvents)
return nil
default:
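The inlined delivery above is a bounded, condition-variable-guarded accept queue: push under acceptMu, wait on acceptCond while the queue is full, and notify the waiter queue only after the lock is released. A minimal standalone sketch of that pattern, using generic names rather than gVisor's endpoint types:

package main

import (
	"container/list"
	"fmt"
	"sync"
)

// acceptQueue is a toy stand-in for the listener's accepted state.
type acceptQueue struct {
	mu        sync.Mutex
	cond      *sync.Cond
	endpoints list.List
	cap       int
	open      bool
}

func newAcceptQueue(capacity int) *acceptQueue {
	q := &acceptQueue{cap: capacity, open: true}
	q.cond = sync.NewCond(&q.mu)
	return q
}

// deliver blocks while the queue is full and reports whether the connection
// was queued; false means the listener is gone and the caller should reset it.
func (q *acceptQueue) deliver(conn string) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	for {
		if !q.open {
			return false
		}
		if q.endpoints.Len() < q.cap {
			q.endpoints.PushBack(conn)
			return true
		}
		q.cond.Wait() // Accept() signals after removing an entry.
	}
}

func main() {
	q := newAcceptQueue(1)
	if q.deliver("conn-1") {
		// As in the code above, notify readers only after the lock is dropped.
		fmt.Println("delivered; notify ReadableEvents")
	}
}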
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5d8e18484..80cd07218 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -30,6 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// InitialRTO is the initial retransmission timeout.
+// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
+const InitialRTO = time.Second
+
// maxSegmentsPerWake is the maximum number of segments to process in the main
// protocol goroutine per wake-up. Yielding [after this number of segments are
// processed] allows other events to be processed as well (e.g., timeouts,
@@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error {
if (n&notifyClose)|(n&notifyAbort) != 0 {
return &tcpip.ErrAborted{}
}
+ if n&notifyShutdown != 0 {
+ return &tcpip.ErrConnectionReset{}
+ }
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
s := h.ep.segmentQueue.dequeue()
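For reference, the InitialRTO above mirrors Linux's TCP_TIMEOUT_INIT of one second; the handshake's resend timer then doubles the timeout until it reaches MaxRTO. A rough standalone sketch of such a backoff timer (the names and the 120s cap are illustrative, not gVisor's actual implementation):

package main

import (
	"fmt"
	"time"
)

const (
	initialRTO = 1 * time.Second   // same value as the InitialRTO added above
	maxRTO     = 120 * time.Second // illustrative upper bound
)

// backoffTimer fires a callback and doubles its timeout on every reset, up to
// a fixed limit.
type backoffTimer struct {
	timeout time.Duration
	limit   time.Duration
	t       *time.Timer
}

func newBackoffTimer(timeout, limit time.Duration, fire func()) *backoffTimer {
	bt := &backoffTimer{timeout: timeout, limit: limit}
	bt.t = time.AfterFunc(timeout, fire)
	return bt
}

// reset doubles the timeout, clamps it to the limit, and re-arms the timer,
// mirroring what happens after each unanswered SYN.
func (bt *backoffTimer) reset() {
	bt.timeout *= 2
	if bt.timeout > bt.limit {
		bt.timeout = bt.limit
	}
	bt.t.Reset(bt.timeout)
}

func main() {
	bt := newBackoffTimer(initialRTO, maxRTO, func() { fmt.Println("retransmit SYN") })
	time.Sleep(initialRTO + 100*time.Millisecond) // first retransmit fires
	bt.reset()                                    // next one after 2s, then 4s, ...
	bt.t.Stop()
}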
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d2b8f298f..407ab2664 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -187,6 +187,8 @@ const (
// say TIME_WAIT.
notifyTickleWorker
notifyError
+ // notifyShutdown means that a connecting socket was shut down.
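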
+ notifyShutdown
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -315,7 +317,10 @@ type accepted struct {
// belong to one list at a time, and endpoints are already stored in the
// dispatcher's list.
endpoints list.List `state:".([]*endpoint)"`
- cap int
+
+ // cap is the maximum number of endpoints that can be in the accepted endpoint
+ // list.
+ cap int
}
// endpoint represents a TCP endpoint. This struct serves as the interface
@@ -333,7 +338,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects accepted.
+// e.acceptMu -> Protects e.accepted.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -573,6 +578,7 @@ type endpoint struct {
// accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
+ // +checklocks:acceptMu
accepted accepted
// The following are only used from the protocol goroutine, and
@@ -2060,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err
@@ -2380,6 +2386,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
+
+ if e.EndpointState().connecting() {
+ // When calling shutdown(2) on a connecting socket, the endpoint must
+ // enter the error state. But this logic cannot belong to the shutdownLocked
+ // method because that method is called during a close(2) (and closing a
+ // connecting socket is not an error).
+ e.resetConnectionLocked(&tcpip.ErrConnectionReset{})
+ e.notifyProtocolGoroutine(notifyShutdown)
+ e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
+ return nil
+ }
+
return e.shutdownLocked(flags)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index f2e8b3840..381f4474d 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
+ e.acceptMu.Lock()
backlog := e.accepted.cap
+ e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
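The acceptMu lock/unlock added around the e.accepted.cap read pairs with the +checklocks:acceptMu annotation introduced on the accepted field: gVisor's checklocks analyzer treats the annotation as a requirement that the field only be touched while the named mutex is held. A minimal sketch of the convention, with hypothetical types:

package main

import (
	"fmt"
	"sync"
)

type listener struct {
	acceptMu sync.Mutex

	// +checklocks:acceptMu
	backlogCap int
}

// backlog reads the annotated field with acceptMu held, which is what the
// analyzer (and the Resume() change above) expects.
func (l *listener) backlog() int {
	l.acceptMu.Lock()
	defer l.acceptMu.Unlock()
	return l.backlogCap
}

func main() {
	l := &listener{}
	l.acceptMu.Lock()
	l.backlogCap = 128
	l.acceptMu.Unlock()
	fmt.Println(l.backlog())
}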
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index 8a026ec46..e47a07030 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rcv_test
+package tcp_test
import (
"testing"
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 2e6ea06f5..2d5fdda19 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
DataSize: seg.data.Size(),
SegMemSize: seg.segMemSize(),
}
- if diff := cmp.Diff(got, want); diff != "" {
+ if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%s differs (-want +got):\n%s", name, diff)
}
}
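The argument swap above matters because go-cmp orders its output by argument: entries from the first argument are prefixed with "-" and entries from the second with "+", so a "(-want +got)" label is only accurate for cmp.Diff(want, got). A tiny illustration:

package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

func main() {
	want := map[string]int{"SegMemSize": 100}
	got := map[string]int{"SegMemSize": 120}
	// "-" lines come from want and "+" lines from got, matching the
	// "(-want +got)" label used in the tests.
	fmt.Println(cmp.Diff(want, got))
}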
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 58817371e..6f1ee3816 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1656,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+func TestShutdownConnectingSocket(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ shutdownMode tcpip.ShutdownFlags
+ }{
+ {"ShutdownRead", tcpip.ShutdownRead},
+ {"ShutdownWrite", tcpip.ShutdownWrite},
+ {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint but don't complete the handshake; we want to
+ // interfere with the handshake process.
+ c.Create(-1)
+
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ // Start connection attempt.
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
+ }
+
+ // Check the SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ if err := c.EP.Shutdown(test.shutdownMode); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ // The endpoint's internal state is updated immediately.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ select {
+ case <-ch:
+ default:
+ t.Fatal("endpoint was not notified")
+ }
+
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, &tcpip.ErrConnectionReset{})
+
+ // If the endpoint is not properly shut down, it will re-attempt to
+ // connect by retransmitting the SYN.
+ c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond))
+ })
+ }
+}
+
func TestSynSent(t *testing.T) {
for _, test := range []struct {
name string
@@ -1679,7 +1744,7 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
@@ -1995,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
)
// Cause a FIN to be generated.
- c.EP.Shutdown(tcpip.ShutdownWrite)
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the FIN but DON't ACK IT.
checker.IPv4(t, c.GetPacket(),
@@ -2011,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// Cause a RST to be generated by closing the read end now since we have
// unread data.
- c.EP.Shutdown(tcpip.ShutdownRead)
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the RST
checker.IPv4(t, c.GetPacket(),
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4255457f9..b355fa7eb 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -243,19 +243,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
cm.HasTOS = true
cm.TOS = p.tos
}
+
+ if e.ops.GetReceivePacketInfo() {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
case header.IPv6ProtocolNumber:
if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
+
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ cm.HasIPv6PacketInfo = true
+ cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: p.packetInfo.NIC,
+ Addr: p.packetInfo.DestinationAddr,
+ }
+ }
default:
panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
- if e.ops.GetReceivePacketInfo() {
- cm.HasIPPacketInfo = true
- cm.PacketInfo = p.packetInfo
- }
+
if e.ops.GetReceiveOriginalDstAddress() {
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
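On the receiving side, the control messages populated above surface through the result of Endpoint.Read. A hedged sketch of how a caller might inspect them, using only the field names that appear in this diff (endpoint setup is elided):

package udpcmsg

import (
	"bytes"
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// readPacketInfo drains one datagram and prints whichever packet-info control
// message the endpoint's address family produced.
func readPacketInfo(ep tcpip.Endpoint) {
	var buf bytes.Buffer
	res, err := ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
	if err != nil {
		fmt.Println("read failed:", err)
		return
	}
	cm := res.ControlMessages
	switch {
	case cm.HasIPPacketInfo: // IPv4 endpoints with ReceivePacketInfo set
		fmt.Printf("IPv4 pktinfo: nic=%d dst=%s\n", cm.PacketInfo.NIC, cm.PacketInfo.DestinationAddr)
	case cm.HasIPv6PacketInfo: // IPv6 endpoints with IPv6ReceivePacketInfo set
		fmt.Printf("IPv6 pktinfo: nic=%d addr=%s\n", cm.IPv6PacketInfo.NIC, cm.IPv6PacketInfo.Addr)
	}
}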
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 3719b0dc7..b3199489c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1369,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
func TestReadIPPacketInfo(t *testing.T) {
tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedLocalAddr tcpip.Address
- expectedDestAddr tcpip.Address
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ checker func(tcpip.NICID) checker.ControlMessagesChecker
}{
{
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedLocalAddr: stackAddr,
- expectedDestAddr: stackAddr,
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ LocalAddr: stackAddr,
+ DestinationAddr: stackAddr,
+ })
+ },
},
{
name: "IPv4 multicast",
proto: header.IPv4ProtocolNumber,
flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastAddr,
- expectedDestAddr: multicastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: multicastAddr,
+ DestinationAddr: multicastAddr,
+ })
+ },
},
{
name: "IPv4 broadcast",
proto: header.IPv4ProtocolNumber,
flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: broadcastAddr,
- expectedDestAddr: broadcastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: broadcastAddr,
+ DestinationAddr: broadcastAddr,
+ })
+ },
},
{
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedLocalAddr: stackV6Addr,
- expectedDestAddr: stackV6Addr,
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: stackV6Addr,
+ })
+ },
},
{
name: "IPv6 multicast",
proto: header.IPv6ProtocolNumber,
flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastV6Addr,
- expectedDestAddr: multicastV6Addr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: multicastV6Addr,
+ })
+ },
},
}
@@ -1449,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- c.ep.SocketOptions().SetReceivePacketInfo(true)
+ switch f := test.flow.netProto(); f {
+ case header.IPv4ProtocolNumber:
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
+ case header.IPv6ProtocolNumber:
+ c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
+ default:
+ t.Fatalf("unhandled protocol number = %d", f)
+ }
- testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: 1,
- LocalAddr: test.expectedLocalAddr,
- DestinationAddr: test.expectedDestAddr,
- }))
+ testRead(c, test.flow, test.checker(c.nicID))
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)