From f1ce97294bfc835a488a1607ad1b36ed349b474e Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Thu, 14 Jan 2021 11:22:24 -0800 Subject: Remove impossible errors Commit 25b5ec7 moved link address resolution out of the transport layer; special handling of link address resolution is no longer necessary in tcp. PiperOrigin-RevId: 351839254 --- pkg/tcpip/transport/tcp/accept.go | 5 +--- pkg/tcpip/transport/tcp/connect.go | 8 ++---- pkg/tcpip/transport/tcp/endpoint.go | 57 ++----------------------------------- 3 files changed, 7 insertions(+), 63 deletions(-) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 9e8872fc9..6921de0f1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -305,10 +305,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Initialize and start the handshake. h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) - if err := h.start(); err != nil { - l.cleanupFailedHandshake(h) - return nil, err - } + h.start() return h, nil } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index f45d26a87..6cdbb8bee 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -53,7 +53,6 @@ const ( wakerForNotification = iota wakerForNewSegment wakerForResend - wakerForResolution ) const ( @@ -460,9 +459,9 @@ func (h *handshake) processSegments() *tcpip.Error { return nil } -// start resolves the route if necessary and sends the first -// SYN/SYN-ACK. -func (h *handshake) start() *tcpip.Error { +// start sends the first SYN/SYN-ACK. It does not block, even if link address +// resolution is required. +func (h *handshake) start() { h.startTime = time.Now() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled @@ -503,7 +502,6 @@ func (h *handshake) start() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) - return nil } // complete completes the TCP 3-way handshake initiated by h.start(). diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ddbed7e46..a4508e871 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2325,68 +2325,17 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - if err := e.startMainLoop(handshake); err != nil { - return err - } - } - - return tcpip.ErrConnectStarted -} - -// startMainLoop sends the initial SYN and starts the main loop for the -// endpoint. -func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { - preloop := func() *tcpip.Error { if handshake { h := e.newHandshake() e.setEndpointState(StateSynSent) - if err := h.start(); err != nil { - e.lastErrorMu.Lock() - e.lastError = err - e.lastErrorMu.Unlock() - - e.setEndpointState(StateError) - e.hardError = err - - // Call cleanupLocked to free up any reservations. - e.cleanupLocked() - return err - } + h.start() } e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - return nil - } - - if e.route.IsResolutionRequired() { - // If the endpoint is closed between releasing e.mu and the goroutine below - // acquiring it, make sure that cleanup is deferred to the new goroutine. e.workerRunning = true - - // Sending the initial SYN may block due to route resolution; do it in a - // separate goroutine to avoid blocking the syscall goroutine. - go func() { // S/R-SAFE: will be drained before save. - e.mu.Lock() - if err := preloop(); err != nil { - e.workerRunning = false - e.mu.Unlock() - return - } - e.mu.Unlock() - _ = e.protocolMainLoop(handshake, nil) - }() - return nil + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } - // No route resolution is required, so we can send the initial SYN here without - // blocking. This will hopefully reduce overall latency by overlapping time - // spent waiting for a SYN-ACK and time spent spinning up a new goroutine - // for the main loop. - if err := preloop(); err != nil { - return err - } - e.workerRunning = true - go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. - return nil + return tcpip.ErrConnectStarted } // ConnectEndpoint is not supported. -- cgit v1.2.3 From dbe4176565b56d9e2f5395e410468a4c98aafd37 Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 14 Jan 2021 13:41:25 -0800 Subject: Check for existence before permissions Return EEXIST when overwritting a file as long as the caller has exec permission on the parent directory, even if the caller doesn't have write permission. Also reordered the mount write check, which happens before permission is checked. Closes #5164 PiperOrigin-RevId: 351868123 --- pkg/sentry/fsimpl/gofer/filesystem.go | 82 +++++++++++++-------------------- pkg/sentry/fsimpl/kernfs/filesystem.go | 7 ++- pkg/sentry/fsimpl/overlay/filesystem.go | 17 ++++--- pkg/sentry/fsimpl/tmpfs/filesystem.go | 9 +++- test/syscalls/linux/mkdir.cc | 33 +++++++++++++ 5 files changed, 91 insertions(+), 57 deletions(-) diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index df27554d3..91d5dc174 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -407,33 +407,44 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + + // Order of checks is important. First check if parent directory can be + // executed, then check for existence, and lastly check if mount is writable. + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return err } name := rp.Component() if name == "." || name == ".." { return syserror.EEXIST } - if len(name) > maxFilenameLen { - return syserror.ENAMETOOLONG - } if parent.isDeleted() { return syserror.ENOENT } + + parent.dirMu.Lock() + defer parent.dirMu.Unlock() + + child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds) + switch { + case err != nil && err != syserror.ENOENT: + return err + case child != nil: + return syserror.EEXIST + } + mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { return err } defer mnt.EndWrite() - parent.dirMu.Lock() - defer parent.dirMu.Unlock() + + if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return err + } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } if parent.isSynthetic() { - if child := parent.children[name]; child != nil { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if createInSyntheticDir == nil { return syserror.EPERM } @@ -449,47 +460,20 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } - if fs.opts.interop == InteropModeShared { - if child := parent.children[name]; child != nil && child.isSynthetic() { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } - // The existence of a non-synthetic dentry at name would be inconclusive - // because the file it represents may have been deleted from the remote - // filesystem, so we would need to make an RPC to revalidate the dentry. - // Just attempt the file creation RPC instead. If a file does exist, the - // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a - // stale dentry exists, the dentry will fail revalidation next time it's - // used. - if err := createInRemoteDir(parent, name, &ds); err != nil { - return err - } - ev := linux.IN_CREATE - if dir { - ev |= linux.IN_ISDIR - } - parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) - return nil - } - if child := parent.children[name]; child != nil { - return syserror.EEXIST - } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } - // No cached dentry exists; however, there might still be an existing file - // at name. As above, we attempt the file creation RPC anyway. + // No cached dentry exists; however, in InteropModeShared there might still be + // an existing file at name. Just attempt the file creation RPC anyways. If a + // file does exist, the RPC will fail with EEXIST like we would have. if err := createInRemoteDir(parent, name, &ds); err != nil { return err } - if child, ok := parent.children[name]; ok && child == nil { - // Delete the now-stale negative dentry. - delete(parent.children, name) + if fs.opts.interop != InteropModeShared { + if child, ok := parent.children[name]; ok && child == nil { + // Delete the now-stale negative dentry. + delete(parent.children, name) + } + parent.touchCMtime() + parent.dirents = nil } - parent.touchCMtime() - parent.dirents = nil ev := linux.IN_CREATE if dir { ev |= linux.IN_ISDIR diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index e77523f22..a7a553619 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -208,7 +208,9 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // * Filesystem.mu must be locked for at least reading. // * isDir(parentInode) == true. func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error { - if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil { + // Order of checks is important. First check if parent directory can be + // executed, then check for existence, and lastly check if mount is writable. + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayExec); err != nil { return err } if name == "." || name == ".." { @@ -223,6 +225,9 @@ func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string if parent.VFSDentry().IsDead() { return syserror.ENOENT } + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite); err != nil { + return err + } return nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index d55bdc97f..e46f593c7 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -480,9 +480,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return err - } name := rp.Component() if name == "." || name == ".." { return syserror.EEXIST @@ -490,11 +487,11 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if parent.vfsd.IsDead() { return syserror.ENOENT } - mnt := rp.Mount() - if err := mnt.CheckBeginWrite(); err != nil { + + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return err } - defer mnt.EndWrite() + parent.dirMu.Lock() defer parent.dirMu.Unlock() @@ -514,6 +511,14 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.ENOENT } + mnt := rp.Mount() + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + return err + } // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. if err := parent.copyUpLocked(ctx); err != nil { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 9296db2fb..453e41d11 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -153,7 +153,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if err != nil { return err } - if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + + // Order of checks is important. First check if parent directory can be + // executed, then check for existence, and lastly check if mount is writable. + if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return err } name := rp.Component() @@ -179,6 +182,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return err } defer mnt.EndWrite() + + if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + return err + } if err := create(parentDir, name); err != nil { return err } diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index 27758203d..11fbfa5c5 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -82,6 +82,39 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { SyscallFailsWithErrno(EACCES)); } +TEST_F(MkdirTest, DirAlreadyExists) { + // Drop capabilities that allow us to override file and directory permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); + auto dir = JoinPath(dirname_.c_str(), "foo"); + EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallSucceeds()); + + struct { + int mode; + int err; + } tests[] = { + {.mode = 0000, .err = EACCES}, // No perm + {.mode = 0100, .err = EEXIST}, // Exec only + {.mode = 0200, .err = EACCES}, // Write only + {.mode = 0300, .err = EEXIST}, // Write+exec + {.mode = 0400, .err = EACCES}, // Read only + {.mode = 0500, .err = EEXIST}, // Read+exec + {.mode = 0600, .err = EACCES}, // Read+write + {.mode = 0700, .err = EEXIST}, // All + }; + for (const auto& t : tests) { + printf("mode: 0%o\n", t.mode); + EXPECT_THAT(chmod(dirname_.c_str(), t.mode), SyscallSucceeds()); + EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(t.err)); + } + + // Clean up. + EXPECT_THAT(chmod(dirname_.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(rmdir(dir.c_str()), SyscallSucceeds()); +} + TEST_F(MkdirTest, MkdirAtEmptyPath) { ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); auto fd = -- cgit v1.2.3 From 833516c139b5fde1b23abab1868798c8309eaa6b Mon Sep 17 00:00:00 2001 From: Arthur Sfez Date: Thu, 14 Jan 2021 15:14:11 -0800 Subject: Add stats for ARP Fixes #4963 Startblock: has LGTM from sbalana and then add reviewer ghanan PiperOrigin-RevId: 351886320 --- pkg/sentry/socket/netstack/netstack.go | 15 ++ pkg/tcpip/network/arp/arp.go | 27 +++- pkg/tcpip/network/arp/arp_test.go | 278 ++++++++++++++++++++++++++++----- pkg/tcpip/tcpip.go | 56 +++++++ 4 files changed, 334 insertions(+), 42 deletions(-) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 94fb425b2..03749a8bf 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -186,6 +186,21 @@ var Metrics = tcpip.Stats{ IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."), IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."), }, + ARP: tcpip.ARPStats{ + PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), + DisabledPacketsReceived: mustCreateMetric("/netstack/arp/disabled_packets_received", "Number of ARP packets received from the link layer when the ARP layer is disabled."), + MalformedPacketsReceived: mustCreateMetric("/netstack/arp/malformed_packets_received", "Number of ARP packets which failed ARP header validation checks."), + RequestsReceived: mustCreateMetric("/netstack/arp/requests_received", "Number of ARP requests received."), + RequestsReceivedUnknownTargetAddress: mustCreateMetric("/netstack/arp/requests_received_unknown_addr", "Number of ARP requests received with an unknown target address."), + OutgoingRequestInterfaceHasNoLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_iface_has_no_addr", "Number of failed attempts to send an ARP request with an interface that has no network address."), + OutgoingRequestBadLocalAddressErrors: mustCreateMetric("/netstack/arp/outgoing_requests_invalid_local_addr", "Number of failed attempts to send an ARP request with a provided local address that is invalid."), + OutgoingRequestNetworkUnreachableErrors: mustCreateMetric("/netstack/arp/outgoing_requests_network_unreachable", "Number of failed attempts to send an ARP request with a network unreachable error."), + OutgoingRequestsDropped: mustCreateMetric("/netstack/arp/outgoing_requests_dropped", "Number of ARP requests which failed to write to a link-layer endpoint."), + OutgoingRequestsSent: mustCreateMetric("/netstack/arp/outgoing_requests_sent", "Number of ARP requests sent."), + RepliesReceived: mustCreateMetric("/netstack/arp/replies_received", "Number of ARP replies received."), + OutgoingRepliesDropped: mustCreateMetric("/netstack/arp/outgoing_replies_dropped", "Number of ARP replies which failed to write to a link-layer endpoint."), + OutgoingRepliesSent: mustCreateMetric("/netstack/arp/outgoing_replies_sent", "Number of ARP replies sent."), + }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 3d5c0d270..3259d052f 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -119,21 +119,28 @@ func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *t } func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats().ARP + stats.PacketsReceived.Increment() + if !e.isEnabled() { + stats.DisabledPacketsReceived.Increment() return } h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { + stats.MalformedPacketsReceived.Increment() return } switch h.Op() { case header.ARPRequest: + stats.RequestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) if e.nud == nil { if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + stats.RequestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -142,6 +149,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + stats.RequestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -177,9 +185,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. - _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) + if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { + stats.OutgoingRepliesDropped.Increment() + } else { + stats.OutgoingRepliesSent.Increment() + } case header.ARPReply: + stats.RepliesReceived.Increment() addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) @@ -233,6 +246,8 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + stats := p.stack.Stats().ARP + if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } @@ -241,15 +256,18 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot if len(localAddr) == 0 { addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if err != nil { + stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment() return err } if len(addr.Address) == 0 { + stats.OutgoingRequestNetworkUnreachableErrors.Increment() return tcpip.ErrNetworkUnreachable } localAddr = addr.Address } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + stats.OutgoingRequestBadLocalAddressErrors.Increment() return tcpip.ErrBadLocalAddress } @@ -269,7 +287,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) + if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + stats.OutgoingRequestsDropped.Increment() + return err + } + stats.OutgoingRequestsSent.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a25cba513..6b61f57ad 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -240,6 +240,10 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr, remoteAddr} { t.Run(strconv.Itoa(i), func(t *testing.T) { + expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1 + expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1 + expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1 + inject(address) pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { @@ -249,6 +253,9 @@ func TestDirectRequest(t *testing.T) { if !rep.IsValid() { t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } + if got := rep.Op(); got != header.ARPReply { + t.Fatalf("got Op = %d, want = %d", got, header.ARPReply) + } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } @@ -261,6 +268,16 @@ func TestDirectRequest(t *testing.T) { if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) } + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived) + } + if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedRequestsReceived) + } + if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent { + t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent) + } }) } @@ -273,6 +290,84 @@ func TestDirectRequest(t *testing.T) { if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } + if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnKnownTargetAddress.Value() = %d, want = 1", got) + } +} + +func TestMalformedPacket(t *testing.T) { + c := newTestContext(t, false) + defer c.cleanup() + + v := make(buffer.View, header.ARPSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + + c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } + if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got) + } +} + +func TestDisabledEndpoint(t *testing.T) { + c := newTestContext(t, false) + defer c.cleanup() + + ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) + if err != nil { + t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err) + } + ep.Disable() + + v := make(buffer.View, header.ARPSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + + c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } + if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got) + } +} + +func TestDirectReply(t *testing.T) { + c := newTestContext(t, false) + defer c.cleanup() + + const senderMAC = "\x01\x02\x03\x04\x05\x06" + const senderIPv4 = "\x0a\x00\x00\x02" + + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPReply) + + copy(h.HardwareAddressSender(), senderMAC) + copy(h.ProtocolAddressSender(), senderIPv4) + copy(h.HardwareAddressTarget(), stackLinkAddr) + copy(h.ProtocolAddressTarget(), stackAddr) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + + c.linkEP.InjectInbound(arp.ProtocolNumber, pkt) + + if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } + if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got) + } } func TestDirectRequestWithNeighborCache(t *testing.T) { @@ -311,6 +406,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + packetsRecv := c.s.Stats().ARP.PacketsReceived.Value() + requestsRecv := c.s.Stats().ARP.RequestsReceived.Value() + requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() + outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value() + // Inject an incoming ARP request. v := make(buffer.View, header.ARPSize) h := header.ARP(v) @@ -323,6 +423,13 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { Data: v.ToVectorisedView(), })) + if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want { + t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want) + } + if !test.isValid { // No packets should be sent after receiving an invalid ARP request. // There is no need to perform a blocking read here, since packets are @@ -330,9 +437,20 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if pkt, ok := c.linkEP.Read(); ok { t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) } + if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want { + t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want) + } + if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want { + t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) + } + return } + if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want { + t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want) + } + // Verify an ARP response was sent. pi, ok := c.linkEP.Read() if !ok { @@ -418,6 +536,8 @@ type testInterface struct { stack.LinkEndpoint nicID tcpip.NICID + + writeErr *tcpip.Error } func (t *testInterface) ID() tcpip.NICID { @@ -441,6 +561,10 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if t.writeErr != nil { + return t.writeErr + } + var r stack.Route r.NetProto = protocol r.ResolveWith(remoteLinkAddr) @@ -458,61 +582,99 @@ func TestLinkAddressRequest(t *testing.T) { localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectedErr *tcpip.Error - expectedLocalAddr tcpip.Address - expectedRemoteLinkAddr tcpip.LinkAddress + linkErr *tcpip.Error + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + expectedRequestsSent uint64 + expectedRequestBadLocalAddressErrors uint64 + expectedRequestNetworkUnreachableErrors uint64 + expectedRequestDroppedErrors uint64 }{ { - name: "Unicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Multicast", - nicAddr: stackAddr, - localAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Unicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: remoteLinkAddr, - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: remoteLinkAddr, + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Multicast with unspecified source", - nicAddr: stackAddr, - remoteLinkAddr: "", - expectedLocalAddr: stackAddr, - expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + expectedRequestsSent: 1, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Unicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrBadLocalAddress, + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Multicast with unassigned address", - localAddr: testAddr, - remoteLinkAddr: "", - expectedErr: tcpip.ErrBadLocalAddress, + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 1, + expectedRequestNetworkUnreachableErrors: 0, }, { - name: "Unicast with no local address available", - remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrNetworkUnreachable, + name: "Unicast with no local address available", + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 1, }, { - name: "Multicast with no local address available", - remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + expectedRequestsSent: 0, + expectedRequestBadLocalAddressErrors: 0, + expectedRequestNetworkUnreachableErrors: 1, + }, + { + name: "Link error", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + linkErr: tcpip.ErrInvalidEndpointState, + expectedErr: tcpip.ErrInvalidEndpointState, + expectedRequestDroppedErrors: 1, }, } @@ -543,10 +705,24 @@ func TestLinkAddressRequest(t *testing.T) { // can mock a link address request and observe the packets sent to the // link endpoint even though the stack uses the real NIC to validate the // local address. - if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr { t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) } + if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { + t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent) + } + if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors) + } + if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors) + } + if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors { + t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors) + } + if test.expectedErr != nil { return } @@ -561,6 +737,9 @@ func TestLinkAddressRequest(t *testing.T) { } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op = %d, want = %d", got, header.ARPRequest) + } if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) } @@ -576,3 +755,22 @@ func TestLinkAddressRequest(t *testing.T) { }) } } + +func TestLinkAddressRequestWithoutNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID) + } + + if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 { + t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 002ddaf67..49d4912ad 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1591,6 +1591,59 @@ type IPStats struct { OptionUnknownReceived *StatCounter } +// ARPStats collects ARP-specific stats. +type ARPStats struct { + // PacketsReceived is the number of ARP packets received from the link layer. + PacketsReceived *StatCounter + + // DisabledPacketsReceived is the number of ARP packets received from the link + // layer when the ARP layer is disabled. + DisabledPacketsReceived *StatCounter + + // MalformedPacketsReceived is the number of ARP packets that were dropped due + // to being malformed. + MalformedPacketsReceived *StatCounter + + // RequestsReceived is the number of ARP requests received. + RequestsReceived *StatCounter + + // RequestsReceivedUnknownTargetAddress is the number of ARP requests that + // were targeted to an interface different from the one it was received on. + RequestsReceivedUnknownTargetAddress *StatCounter + + // OutgoingRequestInterfaceHasNoLocalAddressErrors is the number of failures + // to send an ARP request because the interface has no network address + // assigned to it. + OutgoingRequestInterfaceHasNoLocalAddressErrors *StatCounter + + // OutgoingRequestBadLocalAddressErrors is the number of failures to send an + // ARP request with a bad local address. + OutgoingRequestBadLocalAddressErrors *StatCounter + + // OutgoingRequestNetworkUnreachableErrors is the number of failures to send + // an ARP request with a network unreachable error. + OutgoingRequestNetworkUnreachableErrors *StatCounter + + // OutgoingRequestsDropped is the number of ARP requests which failed to write + // to a link-layer endpoint. + OutgoingRequestsDropped *StatCounter + + // OutgoingRequestSent is the number of ARP requests successfully written to a + // link-layer endpoint. + OutgoingRequestsSent *StatCounter + + // RepliesReceived is the number of ARP replies received. + RepliesReceived *StatCounter + + // OutgoingRepliesDropped is the number of ARP replies which failed to write + // to a link-layer endpoint. + OutgoingRepliesDropped *StatCounter + + // OutgoingRepliesSent is the number of ARP replies successfully written to a + // link-layer endpoint. + OutgoingRepliesSent *StatCounter +} + // TCPStats collects TCP-specific stats. type TCPStats struct { // ActiveConnectionOpenings is the number of connections opened @@ -1743,6 +1796,9 @@ type Stats struct { // IP breaks out IP-specific stats (both v4 and v6). IP IPStats + // ARP breaks out ARP-specific stats. + ARP ARPStats + // TCP breaks out TCP-specific stats. TCP TCPStats -- cgit v1.2.3 From 95371cff350ef5c22c0e0b76ef9474c16e29a6f6 Mon Sep 17 00:00:00 2001 From: Zach Koopmans Date: Thu, 14 Jan 2021 17:02:01 -0800 Subject: Don't run profiles on runc. PiperOrigin-RevId: 351906812 --- Makefile | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 284491c4f..8a1c4321c 100644 --- a/Makefile +++ b/Makefile @@ -323,6 +323,7 @@ containerd-tests: containerd-test-1.4.3 ## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm). ## BENCHMARKS_FILTER - filter to be applied to the test suite. ## BENCHMARKS_OPTIONS - options to be passed to the test. +## BENCHMARKS_PROFILE - profile options to be passed to the test. ## BENCHMARKS_PROJECT ?= gvisor-benchmarks BENCHMARKS_DATASET ?= kokoro @@ -334,7 +335,8 @@ BENCHMARKS_PLATFORMS ?= ptrace BENCHMARKS_TARGETS := //test/benchmarks/media:ffmpeg_test BENCHMARKS_FILTER := . BENCHMARKS_OPTIONS := -test.benchtime=30s -BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex $(BENCHMARKS_OPTIONS) +BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS) +BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema. @$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)) @@ -344,9 +346,10 @@ init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema. run_benchmark = \ ($(call header,BENCHMARK $(1) $(2)); \ set -euo pipefail; \ - if test "$(1)" != "runc"; then $(call install_runtime,$(1),--profile $(2)); fi; \ export T=$$(mktemp --tmpdir logs.$(1).XXXXXX); \ - $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS)) | tee $$T; \ + if test "$(1)" = "runc"; then $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS)) | tee $$T; fi; \ + if test "$(1)" != "runc"; then $(call install_runtime,$(1),--profile $(2)); \ + $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS) $(BENCHMARKS_PROFILE)) | tee $$T; fi; \ if test "$(BENCHMARKS_UPLOAD)" = "true"; then \ $(call run,tools/parsers:parser,parse --debug --file=$$T --runtime=$(1) --suite_name=$(BENCHMARKS_SUITE) --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)); \ fi; \ -- cgit v1.2.3 From e57ebcd37a7b9f98d80e594f2c0baf2220d7b830 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 14 Jan 2021 17:32:38 -0800 Subject: Simplify the pipe implementation. - Remove the pipe package's dependence on the buffer package, which becomes unused as a result. The buffer package is currently intended to serve two use cases, pipes and temporary buffers, and does neither optimally as a result; this change facilitates retooling the buffer package to better serve the latter. - Pass callbacks taking safemem.BlockSeq to the internal pipe I/O methods, which makes most callbacks trivial. - Fix VFS1's splice() and tee() to immediately return if a pipe returns a partial write. PiperOrigin-RevId: 351911375 --- pkg/sentry/kernel/pipe/BUILD | 2 +- pkg/sentry/kernel/pipe/pipe.go | 198 +++++++++++++++++-------------- pkg/sentry/kernel/pipe/pipe_util.go | 99 +++++++--------- pkg/sentry/kernel/pipe/save_restore.go | 26 ++++ pkg/sentry/kernel/pipe/vfs.go | 202 ++++++++++---------------------- pkg/sentry/syscalls/linux/sys_splice.go | 15 ++- test/syscalls/linux/splice.cc | 106 +++++++++++++++++ 7 files changed, 358 insertions(+), 290 deletions(-) create mode 100644 pkg/sentry/kernel/pipe/save_restore.go diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 99134e634..2c32d017d 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -12,6 +12,7 @@ go_library( "pipe_util.go", "reader.go", "reader_writer.go", + "save_restore.go", "vfs.go", "writer.go", ], @@ -19,7 +20,6 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", - "//pkg/buffer", "//pkg/context", "//pkg/marshal/primitive", "//pkg/safemem", diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index b989e14c7..c551acd99 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -21,8 +21,8 @@ import ( "sync/atomic" "syscall" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -75,10 +75,18 @@ type Pipe struct { // mu protects all pipe internal state below. mu sync.Mutex `state:"nosave"` - // view is the underlying set of buffers. + // buf holds the pipe's data. buf is a circular buffer; the first valid + // byte in buf is at offset off, and the pipe contains size valid bytes. + // bufBlocks contains two identical safemem.Blocks representing buf; this + // avoids needing to heap-allocate a new safemem.Block slice when buf is + // resized. bufBlockSeq is a safemem.BlockSeq representing bufBlocks. // - // This is protected by mu. - view buffer.View + // These fields are protected by mu. + buf []byte + bufBlocks [2]safemem.Block `state:"nosave"` + bufBlockSeq safemem.BlockSeq `state:"nosave"` + off int64 + size int64 // max is the maximum size of the pipe in bytes. When this max has been // reached, writers will get EWOULDBLOCK. @@ -99,12 +107,6 @@ type Pipe struct { // // N.B. The size will be bounded. func NewPipe(isNamed bool, sizeBytes int64) *Pipe { - if sizeBytes < MinimumPipeSize { - sizeBytes = MinimumPipeSize - } - if sizeBytes > MaximumPipeSize { - sizeBytes = MaximumPipeSize - } var p Pipe initPipe(&p, isNamed, sizeBytes) return &p @@ -175,75 +177,71 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } -type readOps struct { - // left returns the bytes remaining. - left func() int64 - - // limit limits subsequence reads. - limit func(int64) - - // read performs the actual read operation. - read func(*buffer.View) (int64, error) -} - -// read reads data from the pipe into dst and returns the number of bytes -// read, or returns ErrWouldBlock if the pipe is empty. +// peekLocked passes the first count bytes in the pipe to f and returns its +// result. If fewer than count bytes are available, the safemem.BlockSeq passed +// to f will be less than count bytes in length. // -// Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.readLocked(ctx, ops) -} - -func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { +// peekLocked does not mutate the pipe; if the read consumes bytes from the +// pipe, then the caller is responsible for calling p.consumeLocked() and +// p.Notify(waiter.EventOut). (The latter must be called with p.mu unlocked.) +// +// Preconditions: +// * p.mu must be locked. +// * This pipe must have readers. +func (p *Pipe) peekLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if ops.left() == 0 { + if count == 0 { return 0, nil } - // Is the pipe empty? - if p.view.Size() == 0 { - if !p.HasWriters() { - // There are no writers, return EOF. - return 0, io.EOF + // Limit the amount of data read to the amount of data in the pipe. + if count > p.size { + if p.size == 0 { + if !p.HasWriters() { + return 0, io.EOF + } + return 0, syserror.ErrWouldBlock } - return 0, syserror.ErrWouldBlock + count = p.size } - // Limit how much we consume. - if ops.left() > p.view.Size() { - ops.limit(p.view.Size()) - } + // Prepare the view of the data to be read. + bs := p.bufBlockSeq.DropFirst64(uint64(p.off)).TakeFirst64(uint64(count)) - // Copy user data; the read op is responsible for trimming. - done, err := ops.read(&p.view) - return done, err + // Perform the read. + done, err := f(bs) + return int64(done), err } -type writeOps struct { - // left returns the bytes remaining. - left func() int64 - - // limit should limit subsequent writes. - limit func(int64) - - // write should write to the provided buffer. - write func(*buffer.View) (int64, error) -} - -// write writes data from sv into the pipe and returns the number of bytes -// written. If no bytes are written because the pipe is full (or has less than -// atomicIOBytes free capacity), write returns ErrWouldBlock. +// consumeLocked consumes the first n bytes in the pipe, such that they will no +// longer be visible to future reads. // -// Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.writeLocked(ctx, ops) +// Preconditions: +// * p.mu must be locked. +// * The pipe must contain at least n bytes. +func (p *Pipe) consumeLocked(n int64) { + p.off += n + if max := int64(len(p.buf)); p.off >= max { + p.off -= max + } + p.size -= n } -func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { +// writeLocked passes a safemem.BlockSeq representing the first count bytes of +// unused space in the pipe to f and returns the result. If fewer than count +// bytes are free, the safemem.BlockSeq passed to f will be less than count +// bytes in length. If the pipe is full or otherwise cannot accomodate a write +// of any number of bytes up to count, writeLocked returns ErrWouldBlock +// without calling f. +// +// Unlike peekLocked, writeLocked assumes that f returns the number of bytes +// written to the pipe, and increases the number of bytes stored in the pipe +// accordingly. Callers are still responsible for calling +// p.Notify(waiter.EventIn) with p.mu unlocked. +// +// Preconditions: +// * p.mu must be locked. +func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { // Can't write to a pipe with no readers. if !p.HasReaders() { return 0, syscall.EPIPE @@ -251,29 +249,59 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := ops.left() - avail := p.max - p.view.Size() - if wanted > avail { - if wanted <= atomicIOBytes { + avail := p.max - p.size + short := false + if count > avail { + if count <= atomicIOBytes { return 0, syserror.ErrWouldBlock } - ops.limit(avail) + count = avail + short = true } - // Copy user data. - done, err := ops.write(&p.view) - if err != nil { - return done, err + // Ensure that the buffer is big enough. + if newLen, oldCap := p.size+count, int64(len(p.buf)); newLen > oldCap { + // Allocate a new buffer. + newCap := oldCap * 2 + if oldCap == 0 { + newCap = 8 // arbitrary; sending individual integers across pipes is relatively common + } + for newLen > newCap { + newCap *= 2 + } + if newCap > p.max { + newCap = p.max + } + newBuf := make([]byte, newCap) + // Copy the old buffer's contents to the beginning of the new one. + safemem.CopySeq( + safemem.BlockSeqOf(safemem.BlockFromSafeSlice(newBuf)), + p.bufBlockSeq.DropFirst64(uint64(p.off)).TakeFirst64(uint64(p.size))) + // Switch to the new buffer. + p.buf = newBuf + p.bufBlocks[0] = safemem.BlockFromSafeSlice(newBuf) + p.bufBlocks[1] = p.bufBlocks[0] + p.bufBlockSeq = safemem.BlockSeqFromSlice(p.bufBlocks[:]) + p.off = 0 } - if done < avail { - // Non-failure, but short write. - return done, nil + // Prepare the view of the space to be written. + woff := p.off + p.size + if woff >= int64(len(p.buf)) { + woff -= int64(len(p.buf)) } - if done < wanted { - // Partial write due to full pipe. Note that this could also be - // the short write case above, we would expect a second call - // and the write to return zero bytes in this case. + bs := p.bufBlockSeq.DropFirst64(uint64(woff)).TakeFirst64(uint64(count)) + + // Perform the write. + doneU64, err := f(bs) + done := int64(doneU64) + p.size += done + if done < count || err != nil { + return done, err + } + + // If we shortened the write, adjust the returned error appropriately. + if short { return done, syserror.ErrWouldBlock } @@ -324,7 +352,7 @@ func (p *Pipe) HasWriters() bool { // Precondition: mu must be held. func (p *Pipe) rReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasReaders() && p.view.Size() != 0 { + if p.HasReaders() && p.size != 0 { ready |= waiter.EventIn } if !p.HasWriters() && p.hadWriter { @@ -350,7 +378,7 @@ func (p *Pipe) rReadiness() waiter.EventMask { // Precondition: mu must be held. func (p *Pipe) wReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasWriters() && p.view.Size() < p.max { + if p.HasWriters() && p.size < p.max { ready |= waiter.EventOut } if !p.HasReaders() { @@ -383,7 +411,7 @@ func (p *Pipe) queued() int64 { } func (p *Pipe) queuedLocked() int64 { - return p.view.Size() + return p.size } // FifoSize implements fs.FifoSizer.FifoSize. @@ -406,7 +434,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) { } p.mu.Lock() defer p.mu.Unlock() - if size < p.view.Size() { + if size < p.size { return 0, syserror.EBUSY } p.max = size diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index f665920cb..77246edbe 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -21,9 +21,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -44,46 +44,37 @@ func (p *Pipe) Release(context.Context) { // Read reads from the Pipe into dst. func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) { - n, err := p.read(ctx, readOps{ - left: func() int64 { - return dst.NumBytes() - }, - limit: func(l int64) { - dst = dst.TakeFirst64(l) - }, - read: func(view *buffer.View) (int64, error) { - n, err := dst.CopyOutFrom(ctx, view) - dst = dst.DropFirst64(n) - view.TrimFront(n) - return n, err - }, - }) + n, err := dst.CopyOutFrom(ctx, p) if n > 0 { p.Notify(waiter.EventOut) } return n, err } +// ReadToBlocks implements safemem.Reader.ReadToBlocks for Pipe.Read. +func (p *Pipe) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + n, err := p.read(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }, true /* removeFromSrc */) + return uint64(n), err +} + +func (p *Pipe) read(count int64, f func(srcs safemem.BlockSeq) (uint64, error), removeFromSrc bool) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + n, err := p.peekLocked(count, f) + if n > 0 && removeFromSrc { + p.consumeLocked(n) + } + return n, err +} + // WriteTo writes to w from the Pipe. func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) { - ops := readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadToWriter(w, count) - if !dup { - view.TrimFront(n) - } - count -= n - return n, err - }, - } - n, err := p.read(ctx, ops) - if n > 0 { + n, err := p.read(count, func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.FromIOWriter{w}.WriteFromBlocks(srcs) + }, !dup /* removeFromSrc */) + if n > 0 && !dup { p.Notify(waiter.EventOut) } return n, err @@ -91,39 +82,31 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) // Write writes to the Pipe from src. func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) { - n, err := p.write(ctx, writeOps{ - left: func() int64 { - return src.NumBytes() - }, - limit: func(l int64) { - src = src.TakeFirst64(l) - }, - write: func(view *buffer.View) (int64, error) { - n, err := src.CopyInTo(ctx, view) - src = src.DropFirst64(n) - return n, err - }, - }) + n, err := src.CopyInTo(ctx, p) if n > 0 { p.Notify(waiter.EventIn) } return n, err } +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks for Pipe.Write. +func (p *Pipe) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + n, err := p.write(int64(srcs.NumBytes()), func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }) + return uint64(n), err +} + +func (p *Pipe) write(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + return p.writeLocked(count, f) +} + // ReadFrom reads from r to the Pipe. func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) { - n, err := p.write(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(view *buffer.View) (int64, error) { - n, err := view.WriteFromReader(r, count) - count -= n - return n, err - }, + n, err := p.write(count, func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.FromIOReader{r}.ReadToBlocks(dsts) }) if n > 0 { p.Notify(waiter.EventIn) diff --git a/pkg/sentry/kernel/pipe/save_restore.go b/pkg/sentry/kernel/pipe/save_restore.go new file mode 100644 index 000000000..f135827de --- /dev/null +++ b/pkg/sentry/kernel/pipe/save_restore.go @@ -0,0 +1,26 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "gvisor.dev/gvisor/pkg/safemem" +) + +// afterLoad is called by stateify. +func (p *Pipe) afterLoad() { + p.bufBlocks[0] = safemem.BlockFromSafeSlice(p.buf) + p.bufBlocks[1] = p.bufBlocks[0] + p.bufBlockSeq = safemem.BlockSeqFromSlice(p.bufBlocks[:]) +} diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 2d47d2e82..d5a91730d 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -16,7 +16,6 @@ package pipe import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -269,12 +268,10 @@ func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { // SpliceToNonPipe performs a splice operation from fd to a non-pipe file. func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescription, off, count int64) (int64, error) { fd.pipe.mu.Lock() - defer fd.pipe.mu.Unlock() // Cap the sequence at number of bytes actually available. - v := fd.pipe.queuedLocked() - if v < count { - count = v + if count > fd.pipe.size { + count = fd.pipe.size } src := usermem.IOSequence{ IO: fd, @@ -291,154 +288,97 @@ func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescripti n, err = out.PWrite(ctx, src, off, vfs.WriteOptions{}) } if n > 0 { - fd.pipe.view.TrimFront(n) + fd.pipe.consumeLocked(n) + } + + fd.pipe.mu.Unlock() + + if n > 0 { + fd.pipe.Notify(waiter.EventOut) } return n, err } // SpliceFromNonPipe performs a splice operation from a non-pipe file to fd. func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) { - fd.pipe.mu.Lock() - defer fd.pipe.mu.Unlock() - dst := usermem.IOSequence{ IO: fd, Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), } + var ( + n int64 + err error + ) + fd.pipe.mu.Lock() if off == -1 { - return in.Read(ctx, dst, vfs.ReadOptions{}) + n, err = in.Read(ctx, dst, vfs.ReadOptions{}) + } else { + n, err = in.PRead(ctx, dst, off, vfs.ReadOptions{}) + } + fd.pipe.mu.Unlock() + + if n > 0 { + fd.pipe.Notify(waiter.EventIn) } - return in.PRead(ctx, dst, off, vfs.ReadOptions{}) + return n, err } // CopyIn implements usermem.IO.CopyIn. Note that it is the caller's -// responsibility to trim fd.pipe.view after the read is completed. +// responsibility to call fd.pipe.consumeLocked() and +// fd.pipe.Notify(waiter.EventOut) after the read is completed. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { - origCount := int64(len(dst)) - n, err := fd.pipe.readLocked(ctx, readOps{ - left: func() int64 { - return int64(len(dst)) - }, - limit: func(l int64) { - dst = dst[:l] - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadAt(dst, 0) - return int64(n), err - }, + n, err := fd.pipe.peekLocked(int64(len(dst)), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(dst)), srcs) }) - if n > 0 { - fd.pipe.Notify(waiter.EventOut) - } - if err == nil && n != origCount { - return int(n), syserror.ErrWouldBlock - } return int(n), err } -// CopyOut implements usermem.IO.CopyOut. +// CopyOut implements usermem.IO.CopyOut. Note that it is the caller's +// responsibility to call fd.pipe.Notify(waiter.EventIn) after the +// write is completed. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { - origCount := int64(len(src)) - n, err := fd.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return int64(len(src)) - }, - limit: func(l int64) { - src = src[:l] - }, - write: func(view *buffer.View) (int64, error) { - view.Append(src) - return int64(len(src)), nil - }, + n, err := fd.pipe.writeLocked(int64(len(src)), func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(src))) }) - if n > 0 { - fd.pipe.Notify(waiter.EventIn) - } - if err == nil && n != origCount { - return int(n), syserror.ErrWouldBlock - } return int(n), err } // ZeroOut implements usermem.IO.ZeroOut. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { - origCount := toZero - n, err := fd.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return toZero - }, - limit: func(l int64) { - toZero = l - }, - write: func(view *buffer.View) (int64, error) { - view.Grow(view.Size()+toZero, true /* zero */) - return toZero, nil - }, + n, err := fd.pipe.writeLocked(toZero, func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.ZeroSeq(dsts) }) - if n > 0 { - fd.pipe.Notify(waiter.EventIn) - } - if err == nil && n != origCount { - return n, syserror.ErrWouldBlock - } return n, err } // CopyInTo implements usermem.IO.CopyInTo. Note that it is the caller's -// responsibility to trim fd.pipe.view after the read is completed. +// responsibility to call fd.pipe.consumeLocked() and +// fd.pipe.Notify(waiter.EventOut) after the read is completed. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { - count := ars.NumBytes() - if count == 0 { - return 0, nil - } - origCount := count - n, err := fd.pipe.readLocked(ctx, readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadToSafememWriter(dst, uint64(count)) - return int64(n), err - }, + return fd.pipe.peekLocked(ars.NumBytes(), func(srcs safemem.BlockSeq) (uint64, error) { + return dst.WriteFromBlocks(srcs) }) - if n > 0 { - fd.pipe.Notify(waiter.EventOut) - } - if err == nil && n != origCount { - return n, syserror.ErrWouldBlock - } - return n, err } // CopyOutFrom implements usermem.IO.CopyOutFrom. +// +// Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { - count := ars.NumBytes() - if count == 0 { - return 0, nil - } - origCount := count - n, err := fd.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(view *buffer.View) (int64, error) { - n, err := view.WriteFromSafememReader(src, uint64(count)) - return int64(n), err - }, + n, err := fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { + return src.ReadToBlocks(dsts) }) if n > 0 { fd.pipe.Notify(waiter.EventIn) } - if err == nil && n != origCount { - return n, syserror.ErrWouldBlock - } return n, err } @@ -481,37 +421,23 @@ func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFr } lockTwoPipes(dst.pipe, src.pipe) - defer dst.pipe.mu.Unlock() - defer src.pipe.mu.Unlock() - - n, err := dst.pipe.writeLocked(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(dstView *buffer.View) (int64, error) { - return src.pipe.readLocked(ctx, readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(srcView *buffer.View) (int64, error) { - n, err := srcView.ReadToSafememWriter(dstView, uint64(count)) - if n > 0 && removeFromSrc { - srcView.TrimFront(int64(n)) - } - return int64(n), err - }, - }) - }, + n, err := dst.pipe.writeLocked(count, func(dsts safemem.BlockSeq) (uint64, error) { + n, err := src.pipe.peekLocked(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }) + if n > 0 && removeFromSrc { + src.pipe.consumeLocked(n) + } + return uint64(n), err }) + dst.pipe.mu.Unlock() + src.pipe.mu.Unlock() + if n > 0 { dst.pipe.Notify(waiter.EventIn) - src.pipe.Notify(waiter.EventOut) + if removeFromSrc { + src.pipe.Notify(waiter.EventOut) + } } return n, err } diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 1c4cdb0dd..134051124 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,24 +29,23 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) { return 0, syserror.EINVAL } - + if opts.Length == 0 { + return 0, nil + } if opts.Length > int64(kernel.MAX_RW_COUNT) { opts.Length = int64(kernel.MAX_RW_COUNT) } var ( - total int64 n int64 err error inCh chan struct{} outCh chan struct{} ) - for opts.Length > 0 { + for { n, err = fs.Splice(t, outFile, inFile, opts) - opts.Length -= n - total += n - if err != syserror.ErrWouldBlock { + if n != 0 || err != syserror.ErrWouldBlock { break } else if err == syserror.ErrWouldBlock && nonBlocking { break @@ -87,13 +86,13 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB } } - if total > 0 { + if n > 0 { // On Linux, inotify behavior is not very consistent with splice(2). We try // our best to emulate Linux for very basic calls to splice, where for some // reason, events are generated for output files, but not input files. outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0) } - return total, err + return n, err } // Sendfile implements linux system call sendfile(2). diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index c2369db54..e5730a606 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -483,6 +483,112 @@ TEST(SpliceTest, TwoPipes) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); } +TEST(SpliceTest, TwoPipesPartialRead) { + // Create two pipes. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor first_rfd(fds[0]); + const FileDescriptor first_wfd(fds[1]); + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor second_rfd(fds[0]); + const FileDescriptor second_wfd(fds[1]); + + // Write half a page of data to the first pipe. + std::vector buf(kPageSize / 2); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize / 2)); + + // Attempt to splice one page from the first pipe to the second; it should + // immediately return after splicing the half-page previously written to the + // first pipe. + EXPECT_THAT( + splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, kPageSize, 0), + SyscallSucceedsWithValue(kPageSize / 2)); +} + +TEST(SpliceTest, TwoPipesPartialWrite) { + // Create two pipes. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor first_rfd(fds[0]); + const FileDescriptor first_wfd(fds[1]); + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor second_rfd(fds[0]); + const FileDescriptor second_wfd(fds[1]); + + // Write two pages of data to the first pipe. + std::vector buf(2 * kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(2 * kPageSize)); + + // Limit the second pipe to two pages, then write one page of data to it. + ASSERT_THAT(fcntl(second_wfd.get(), F_SETPIPE_SZ, 2 * kPageSize), + SyscallSucceeds()); + ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size() / 2), + SyscallSucceedsWithValue(kPageSize)); + + // Attempt to splice two pages from the first pipe to the second; it should + // immediately return after splicing the first page previously written to the + // first pipe. + EXPECT_THAT(splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, + 2 * kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); +} + +TEST(TeeTest, TwoPipesPartialRead) { + // Create two pipes. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor first_rfd(fds[0]); + const FileDescriptor first_wfd(fds[1]); + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor second_rfd(fds[0]); + const FileDescriptor second_wfd(fds[1]); + + // Write half a page of data to the first pipe. + std::vector buf(kPageSize / 2); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize / 2)); + + // Attempt to tee one page from the first pipe to the second; it should + // immediately return after copying the half-page previously written to the + // first pipe. + EXPECT_THAT(tee(first_rfd.get(), second_wfd.get(), kPageSize, 0), + SyscallSucceedsWithValue(kPageSize / 2)); +} + +TEST(TeeTest, TwoPipesPartialWrite) { + // Create two pipes. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor first_rfd(fds[0]); + const FileDescriptor first_wfd(fds[1]); + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor second_rfd(fds[0]); + const FileDescriptor second_wfd(fds[1]); + + // Write two pages of data to the first pipe. + std::vector buf(2 * kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(2 * kPageSize)); + + // Limit the second pipe to two pages, then write one page of data to it. + ASSERT_THAT(fcntl(second_wfd.get(), F_SETPIPE_SZ, 2 * kPageSize), + SyscallSucceeds()); + ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size() / 2), + SyscallSucceedsWithValue(kPageSize)); + + // Attempt to tee two pages from the first pipe to the second; it should + // immediately return after copying the first page previously written to the + // first pipe. + EXPECT_THAT(tee(first_rfd.get(), second_wfd.get(), 2 * kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); +} + TEST(SpliceTest, TwoPipesCircular) { // This test deadlocks the sentry on VFS1 because VFS1 splice ordering is // based on fs.File.UniqueID, which does not prevent circular ordering between -- cgit v1.2.3