diff options
author | Adin Scannell <ascannell@google.com> | 2019-09-12 17:42:14 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-09-12 17:43:27 -0700 |
commit | 7c6ab6a219f37a1d4c18ced4a602458fcf363f85 (patch) | |
tree | 77ab0d0aaad32a81b09ee7ad55b97b80aaec5402 /pkg/tcpip | |
parent | df5d377521e625aeb8f4fe18bd1d9974dbf9998c (diff) |
Implement splice methods for pipes and sockets.
This also allows the tee(2) implementation to be enabled, since dup can now be
properly supported via WriteTo.
Note that this change necessitated some minor restructoring with the
fs.FileOperations splice methods. If the *fs.File is passed through directly,
then only public API methods are accessible, which will deadlock immediately
since the locking is already done by fs.Splice. Instead, we pass through an
abstract io.Reader or io.Writer, which elide locks and use the underlying
fs.FileOperations directly.
PiperOrigin-RevId: 268805207
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/udp.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 48 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 14 |
7 files changed, 83 insertions, 67 deletions
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 87d1e0d0d..847d02982 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ebf8a2d04..2534069ab 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -261,31 +261,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -338,7 +341,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -432,6 +435,11 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e1f622af6..3db060384 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -204,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -289,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..cf1c5c433 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() return 0, nil, err } @@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 // destination address, route using that address. if !ep.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { + if !ip.IsValid(len(payloadBytes)) { ep.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..dd931f88c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -806,7 +806,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -821,47 +821,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +880,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dccb9a7eb..6ac7c067a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -277,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -370,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { |