summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2021-01-07 22:25:30 +0000
committergVisor bot <gvisor-bot@google.com>2021-01-07 22:25:30 +0000
commitf1f3952a222bb5bd930d8425d25e8fd9fb559e27 (patch)
treebed48bf1e34d27a19978fbfd3a5a9e4482dc1362
parent82aa687dfb9645a5d0994a97ae4d955e61685e2e (diff)
parentb1de1da318631c6d29f6c04dea370f712078f443 (diff)
Merge release-20201208.0-127-gb1de1da31 (automated)
-rw-r--r--pkg/sentry/socket/netstack/netstack.go265
-rw-r--r--pkg/sentry/socket/netstack/netstack_state_autogen.go34
-rw-r--r--pkg/syserr/netstack.go2
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go57
-rw-r--r--pkg/tcpip/buffer/view.go37
-rw-r--r--pkg/tcpip/tcpip.go81
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go38
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go46
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go34
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go228
-rw-r--r--pkg/tcpip/transport/tcp/segment.go16
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go58
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
14 files changed, 470 insertions, 482 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index fe11fca9c..dcf898c0a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -28,9 +28,9 @@ import (
"bytes"
"fmt"
"io"
+ "io/ioutil"
"math"
"reflect"
- "sync/atomic"
"syscall"
"time"
@@ -43,7 +43,6 @@ import (
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -308,16 +307,8 @@ type socketOpsCommon struct {
skType linux.SockType
protocol int
- // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
- // Must be accessed using atomic operations. It must only be written
- // with readMu held but can be read without holding readMu. The latter
- // is required to avoid deadlocks in epoll Readiness checks.
- readViewHasData uint32
-
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
- // readView contains the remaining payload from the last packet.
- readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
readCM socket.IPControlMessages
@@ -336,8 +327,8 @@ type socketOpsCommon struct {
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
- // sockOptInq corresponds to TCP_INQ. It is implemented at this level
- // because it takes into account data from readView.
+ // TODO(b/153685824): Move this to SocketOptions.
+ // sockOptInq corresponds to TCP_INQ.
sockOptInq bool
}
@@ -377,41 +368,23 @@ func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
-// fetchReadView updates the readView field of the socket if it's currently
-// empty. It assumes that the socket is locked.
-//
// Precondition: s.readMu must be held.
-func (s *socketOpsCommon) fetchReadView() *syserr.Error {
- if len(s.readView) > 0 {
- return nil
- }
- s.readView = nil
- s.sender = tcpip.FullAddress{}
- s.linkPacketInfo = tcpip.LinkPacketInfo{}
+func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
+ res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
+ Peek: peek,
+ NeedRemoteAddr: true,
+ NeedLinkPacketInfo: true,
+ })
- var v buffer.View
- var cms tcpip.ControlMessages
- var err *tcpip.Error
+ // Assign these anyways.
+ s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
+ s.sender = res.RemoteAddr
+ s.linkPacketInfo = res.LinkPacketInfo
- switch e := s.Endpoint.(type) {
- // The ordering of these interfaces matters. The most specific
- // interfaces must be specified before the more generic Endpoint
- // interface.
- case tcpip.PacketEndpoint:
- v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
- case tcpip.Endpoint:
- v, cms, err = e.Read(&s.sender)
- }
if err != nil {
- atomic.StoreUint32(&s.readViewHasData, 0)
- return syserr.TranslateNetstackError(err)
+ return 0, 0, syserr.TranslateNetstackError(err)
}
-
- s.readView = v
- s.readCM = socket.NewIPControlMessages(s.family, cms)
- atomic.StoreUint32(&s.readViewHasData, 1)
-
- return nil
+ return res.Count, res.Total, nil
}
// Release implements fs.FileOperations.Release.
@@ -460,38 +433,14 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// WriteTo implements fs.FileOperations.WriteTo.
func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
+ defer s.readMu.Unlock()
- // Copy as much data as possible.
- done := int64(0)
- for count > 0 {
- // This may return a blocking error.
- if err := s.fetchReadView(); err != nil {
- s.readMu.Unlock()
- return done, err.ToError()
- }
-
- // Write to the underlying file.
- n, err := dst.Write(s.readView)
- done += int64(n)
- count -= int64(n)
- if dup {
- // That's all we support for dup. This is generally
- // supported by any Linux system calls, but the
- // expectation is that now a caller will call read to
- // actually remove these bytes from the socket.
- break
- }
-
- // Drop that part of the view.
- s.readView.TrimFront(n)
- if err != nil {
- s.readMu.Unlock()
- return done, err
- }
+ // This may return a blocking error.
+ n, _, err := s.readLocked(dst, int(count), dup /* peek */)
+ if err != nil {
+ return 0, err.ToError()
}
-
- s.readMu.Unlock()
- return done, nil
+ return int64(n), nil
}
// ioSequencePayload implements tcpip.Payload.
@@ -627,17 +576,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
// Readiness returns a mask of ready events for socket s.
func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
- r := s.Endpoint.Readiness(mask)
-
- // Check our cached value iff the caller asked for readability and the
- // endpoint itself is currently not readable.
- if (mask & ^r & waiter.EventIn) != 0 {
- if atomic.LoadUint32(&s.readViewHasData) == 1 {
- r |= waiter.EventIn
- }
- }
-
- return r
+ return s.Endpoint.Readiness(mask)
}
func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
@@ -2618,66 +2557,20 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return a, l, nil
}
-// coalescingRead is the fast path for non-blocking, non-peek, stream-based
-// case. It coalesces as many packets as possible before returning to the
-// caller.
+// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
//
// Precondition: s.readMu must be locked.
-func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
- var err *syserr.Error
- var copied int
-
- // Copy as many views as possible into the user-provided buffer.
- for {
- // Always do at least one fetchReadView, even if the number of bytes to
- // read is 0.
- err = s.fetchReadView()
- if err != nil || len(s.readView) == 0 {
- break
- }
- if dst.NumBytes() == 0 {
- break
- }
-
- var n int
- var e error
- if discard {
- n = len(s.readView)
- if int64(n) > dst.NumBytes() {
- n = int(dst.NumBytes())
- }
- } else {
- n, e = dst.CopyOut(ctx, s.readView)
- // Set the control message, even if 0 bytes were read.
- if e == nil {
- s.updateTimestamp()
- }
- }
- copied += n
- s.readView.TrimFront(n)
-
- dst = dst.DropFirst(n)
- if e != nil {
- err = syserr.FromError(e)
- break
- }
- // If we are done reading requested data then stop.
- if dst.NumBytes() == 0 {
- break
- }
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
+ // Always do at least one read, even if the number of bytes to read is 0.
+ var n int
+ n, _, err := s.readLocked(dst, count, false /* peek */)
+ if err != nil {
+ return 0, err
}
-
- // If we managed to copy something, we must deliver it.
- if copied > 0 {
- s.Endpoint.ModerateRecvBuf(copied)
- return copied, nil
+ if n > 0 {
+ s.Endpoint.ModerateRecvBuf(n)
}
-
- return 0, err
+ return n, nil
}
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
@@ -2689,7 +2582,7 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
return
}
cmsg.IP.HasInq = true
- cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+ cmsg.IP.Inq = int32(rcvBufUsed)
}
func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
@@ -2726,7 +2619,21 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
s.readMu.Lock()
- n, err := s.coalescingRead(ctx, dst, trunc)
+
+ var w io.Writer
+ if trunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
+
+ if err == nil && !trunc {
+ // Set the control message, even if 0 bytes were read.
+ s.updateTimestamp()
+ }
+
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
s.readMu.Unlock()
@@ -2736,18 +2643,32 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
s.readMu.Lock()
defer s.readMu.Unlock()
- if err := s.fetchReadView(); err != nil {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read, and does not write to buffer.
+ isTCPPeekTrunc := !isPacket && peek && trunc
+
+ var w io.Writer
+ if isTCPPeekTrunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ var numRead, numTotal int
+ var err *syserr.Error
+ numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, err
}
- if !isPacket && peek && trunc {
- // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
- // amount that could be read.
+ if isTCPPeekTrunc {
+ // TCP endpoint does not return the total bytes in buffer as numTotal.
+ // We need to query it from socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
- available := len(s.readView) + int(rql)
+ available := int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, 0, nil, 0, socket.ControlMessages{}, nil
@@ -2755,11 +2676,9 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
- n, err := dst.CopyOut(ctx, s.readView)
// Set the control message, even if 0 bytes were read.
- if err == nil {
- s.updateTimestamp()
- }
+ s.updateTimestamp()
+
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
@@ -2772,58 +2691,33 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
}
if peek {
- if l := len(s.readView); trunc && l > n {
+ if trunc && numTotal > numRead {
// isPacket must be true.
- return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
}
-
- if isPacket || err != nil {
- return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
- }
-
- // We need to peek beyond the first message.
- dst = dst.DropFirst(n)
- num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, err := s.Endpoint.Peek(dsts)
- // TODO(b/78348848): Handle peek timestamp.
- if err != nil {
- return int64(n), syserr.TranslateNetstackError(err).ToError()
- }
- return int64(n), nil
- }})
- n += int(num)
- if err == syserror.ErrWouldBlock && n > 0 {
- // We got some data, so no need to return an error.
- err = nil
- }
- return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ return numRead, 0, nil, 0, s.controlMessages(), nil
}
var msgLen int
if isPacket {
- msgLen = len(s.readView)
- s.readView = nil
+ msgLen = numTotal
} else {
- msgLen = int(n)
- s.readView.TrimFront(int(n))
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+ msgLen = numRead
}
var flags int
- if msgLen > int(n) {
+ if msgLen > numRead {
flags |= linux.MSG_TRUNC
}
+ n := numRead
if trunc {
n = msgLen
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
- return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
+ return n, flags, addr, addrLen, cmsg, nil
}
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
@@ -3090,11 +2984,6 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
return 0, syserr.TranslateNetstackError(terr).ToError()
}
- // Add bytes removed from the endpoint but not yet sent to the caller.
- s.readMu.Lock()
- v += len(s.readView)
- s.readMu.Unlock()
-
if v > math.MaxInt32 {
v = math.MaxInt32
}
diff --git a/pkg/sentry/socket/netstack/netstack_state_autogen.go b/pkg/sentry/socket/netstack/netstack_state_autogen.go
index 8465d8743..9925e2e9e 100644
--- a/pkg/sentry/socket/netstack/netstack_state_autogen.go
+++ b/pkg/sentry/socket/netstack/netstack_state_autogen.go
@@ -41,8 +41,6 @@ func (s *socketOpsCommon) StateFields() []string {
"Endpoint",
"skType",
"protocol",
- "readViewHasData",
- "readView",
"readCM",
"sender",
"linkPacketInfo",
@@ -63,15 +61,13 @@ func (s *socketOpsCommon) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(3, &s.Endpoint)
stateSinkObject.Save(4, &s.skType)
stateSinkObject.Save(5, &s.protocol)
- stateSinkObject.Save(6, &s.readViewHasData)
- stateSinkObject.Save(7, &s.readView)
- stateSinkObject.Save(8, &s.readCM)
- stateSinkObject.Save(9, &s.sender)
- stateSinkObject.Save(10, &s.linkPacketInfo)
- stateSinkObject.Save(11, &s.sockOptTimestamp)
- stateSinkObject.Save(12, &s.timestampValid)
- stateSinkObject.Save(13, &s.timestampNS)
- stateSinkObject.Save(14, &s.sockOptInq)
+ stateSinkObject.Save(6, &s.readCM)
+ stateSinkObject.Save(7, &s.sender)
+ stateSinkObject.Save(8, &s.linkPacketInfo)
+ stateSinkObject.Save(9, &s.sockOptTimestamp)
+ stateSinkObject.Save(10, &s.timestampValid)
+ stateSinkObject.Save(11, &s.timestampNS)
+ stateSinkObject.Save(12, &s.sockOptInq)
}
func (s *socketOpsCommon) afterLoad() {}
@@ -83,15 +79,13 @@ func (s *socketOpsCommon) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(3, &s.Endpoint)
stateSourceObject.Load(4, &s.skType)
stateSourceObject.Load(5, &s.protocol)
- stateSourceObject.Load(6, &s.readViewHasData)
- stateSourceObject.Load(7, &s.readView)
- stateSourceObject.Load(8, &s.readCM)
- stateSourceObject.Load(9, &s.sender)
- stateSourceObject.Load(10, &s.linkPacketInfo)
- stateSourceObject.Load(11, &s.sockOptTimestamp)
- stateSourceObject.Load(12, &s.timestampValid)
- stateSourceObject.Load(13, &s.timestampNS)
- stateSourceObject.Load(14, &s.sockOptInq)
+ stateSourceObject.Load(6, &s.readCM)
+ stateSourceObject.Load(7, &s.sender)
+ stateSourceObject.Load(8, &s.linkPacketInfo)
+ stateSourceObject.Load(9, &s.sockOptTimestamp)
+ stateSourceObject.Load(10, &s.timestampValid)
+ stateSourceObject.Load(11, &s.timestampNS)
+ stateSourceObject.Load(12, &s.sockOptInq)
}
func (s *SocketVFS2) StateTypeName() string {
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 77c3c110c..2756d4471 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -48,6 +48,7 @@ var (
ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
+ ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT)
)
var netstackErrorTranslations map[string]*Error
@@ -100,6 +101,7 @@ func init() {
addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled)
addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet)
addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported)
+ addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer)
}
// TranslateNetstackError converts an error from the tcpip package to a sentry
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 4f551cd92..7193f56ad 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -286,45 +286,47 @@ type opErrorer interface {
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
default:
}
- read, _, err := ep.Read(addr)
+ w := tcpip.SliceWriter(b)
+ opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
+ res, err := ep.Read(&w, len(b), opts)
if err == tcpip.ErrWouldBlock {
- if dontWait {
- return nil, errWouldBlock
- }
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- read, _, err = ep.Read(addr)
+ res, err = ep.Read(&w, len(b), opts)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
case <-notifyCh:
}
}
}
if err == tcpip.ErrClosedForReceive {
- return nil, io.EOF
+ return 0, io.EOF
}
if err != nil {
- return nil, errorer.newOpError("read", errors.New(err.String()))
+ return 0, errorer.newOpError("read", errors.New(err.String()))
}
- return read, nil
+ if addr != nil {
+ *addr = res.RemoteAddr
+ }
+ return res.Count, nil
}
// Read implements net.Conn.Read.
@@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) {
deadline := c.readCancel()
- numRead := 0
- defer func() {
- if numRead != 0 {
- c.ep.ModerateRecvBuf(numRead)
- }
- }()
- for numRead != len(b) {
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
- if err != nil {
- if numRead != 0 {
- return numRead, nil
- }
- return numRead, err
- }
- }
- n := copy(b[numRead:], c.read)
- c.read.TrimFront(n)
- numRead += n
- if len(c.read) == 0 {
- c.read = nil
- }
+ n, err := commonRead(b, c.ep, c.wq, deadline, nil, c)
+ if n != 0 {
+ c.ep.ModerateRecvBuf(n)
}
- return numRead, nil
+ return n, err
}
// Write implements net.Conn.Write.
@@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
+ n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c)
if err != nil {
return 0, nil, err
}
-
- return copy(b, read), fullToUDPAddr(addr), nil
+ return n, fullToUDPAddr(addr), nil
}
func (c *UDPConn) Write(b []byte) (int, error) {
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8db70a700..5dd1b1b6b 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) {
}
// Read implements io.Reader.
-func (vv *VectorisedView) Read(v View) (copied int, err error) {
- count := len(v)
+func (vv *VectorisedView) Read(b []byte) (copied int, err error) {
+ count := len(b)
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
vv.size -= count
- copy(v[copied:], vv.views[0][:count])
+ copy(b[copied:], vv.views[0][:count])
vv.views[0].TrimFront(count)
copied += count
return copied, nil
}
count -= len(vv.views[0])
- copy(v[copied:], vv.views[0])
+ copy(b[copied:], vv.views[0])
copied += len(vv.views[0])
vv.removeFirst()
}
@@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
return copied
}
+// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
+// unless peek is true.
+func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+ var err error
+ done := 0
+ for _, v := range vv.Views() {
+ remaining := count - done
+ if remaining <= 0 {
+ break
+ }
+ if len(v) > remaining {
+ v = v[:remaining]
+ }
+
+ var n int
+ n, err = dst.Write(v)
+ if n > 0 {
+ done += n
+ }
+ if err != nil {
+ break
+ }
+ }
+ if !peek {
+ vv.TrimFront(done)
+ }
+ return done, err
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index ef0f51f1a..f798056c0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "io"
"math/bits"
"reflect"
"strconv"
@@ -39,7 +40,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -113,6 +113,7 @@ var (
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
ErrMalformedHeader = &Error{msg: "header is malformed"}
+ ErrBadBuffer = &Error{msg: "bad buffer"}
)
var messageToError map[string]*Error
@@ -162,6 +163,7 @@ func StringToError(s string) *Error {
ErrNotPermitted,
ErrAddressFamilyNotSupported,
ErrMalformedHeader,
+ ErrBadBuffer,
}
messageToError = make(map[string]*Error)
@@ -496,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) {
return s[:size], nil
}
+var _ io.Writer = (*SliceWriter)(nil)
+
+// SliceWriter implements io.Writer for slices.
+type SliceWriter []byte
+
+// Write implements io.Writer.Write.
+func (s *SliceWriter) Write(b []byte) (int, error) {
+ n := copy(*s, b)
+ *s = (*s)[n:]
+ if n < len(b) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
// A ControlMessages contains socket control messages for IP sockets.
//
// +stateify savable
@@ -552,6 +569,40 @@ type PacketOwner interface {
GID() uint32
}
+// ReadOptions contains options for Endpoint.Read.
+type ReadOptions struct {
+ // Peek indicates whether this read is a peek.
+ Peek bool
+
+ // NeedRemoteAddr indicates whether to return the remote address, if
+ // supported.
+ NeedRemoteAddr bool
+
+ // NeedLinkPacketInfo indicates whether to return the link-layer information,
+ // if supported.
+ NeedLinkPacketInfo bool
+}
+
+// ReadResult represents result for a successful Endpoint.Read.
+type ReadResult struct {
+ // Count is the number of bytes received and written to the buffer.
+ Count int
+
+ // Total is the number of bytes of the received packet. This can be used to
+ // determine whether the read is truncated.
+ Total int
+
+ // ControlMessages is the control messages received.
+ ControlMessages ControlMessages
+
+ // RemoteAddr is the remote address if ReadOptions.NeedAddr is true.
+ RemoteAddr FullAddress
+
+ // LinkPacketInfo is the link-layer information of the received packet if
+ // ReadOptions.NeedLinkPacketInfo is true.
+ LinkPacketInfo LinkPacketInfo
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -566,11 +617,15 @@ type Endpoint interface {
// Abort is best effort; implementing Abort with Close is acceptable.
Abort()
- // Read reads data from the endpoint and optionally returns the sender.
+ // Read reads data from the endpoint and optionally writes to dst.
+ //
+ // This method does not block if there is no data pending; in this case,
+ // ErrWouldBlock is returned.
//
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+ // If non-zero number of bytes are successfully read and written to dst, err
+ // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
+ // should be returned.
+ Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -592,11 +647,6 @@ type Endpoint interface {
// not). The channel is only non-nil in this case.
Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
- // Peek reads data without consuming it from the endpoint.
- //
- // This method does not block if there is no data pending.
- Peek([][]byte) (int64, *Error)
-
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
@@ -703,17 +753,6 @@ type LinkPacketInfo struct {
PktType PacketType
}
-// PacketEndpoint are additional methods that are only implemented by Packet
-// endpoints.
-type PacketEndpoint interface {
- // ReadPacket reads a datagram/packet from the endpoint and optionally
- // returns the sender and additional LinkPacketInfo.
- //
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
-}
-
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d1e4a7cb7..2eb4457df 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,8 @@
package icmp
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -151,9 +153,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -163,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = p.senderAddress
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -329,11 +344,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
return nil
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e5e247342..3ab060751 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -26,6 +26,7 @@ package packet
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -160,8 +161,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.PacketEndpoint.ReadPacket.
-func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -173,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn
err = tcpip.ErrClosedForReceive
}
ep.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ if !opts.Peek {
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+ }
ep.rcvMu.Unlock()
- if addr != nil {
- *addr = packet.senderAddr
+ res := tcpip.ReadResult{
+ Total: packet.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: packet.timestampNS,
+ },
}
-
- if info != nil {
- *info = packet.packetInfo
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = packet.senderAddr
+ }
+ if opts.NeedLinkPacketInfo {
+ res.LinkPacketInfo = packet.packetInfo
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return ep.ReadPacket(addr, nil)
+ n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -203,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, tcpip.ErrInvalidOptionValue
}
-// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
func (*endpoint) Disconnect() *tcpip.Error {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7befcfc9b..dd260535f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -190,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -202,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
pkt := e.rcvList.Front()
- e.rcvList.Remove(pkt)
- e.rcvBufSize -= pkt.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = pkt.senderAddr
+ res := tcpip.ReadResult{
+ Total: pkt.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: pkt.timestampNS,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = pkt.senderAddr
}
- return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+ n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -363,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
return int64(len(payloadBytes)), nil, nil
}
-// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() *tcpip.Error {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6e3c8860e..8f3981075 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"encoding/binary"
"fmt"
+ "io"
"math"
"runtime"
"strings"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
@@ -393,15 +393,28 @@ type endpoint struct {
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
- // The following fields are used to manage the receive queue. The
- // protocol goroutine adds ready-for-delivery segments to rcvList,
- // which are returned by Read() calls to users.
+ // rcvReadMu synchronizes calls to Read.
//
- // Once the peer has closed its send side, rcvClosed is set to true
- // to indicate to users that no more data is coming.
+ // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // must be held during each read to ensure atomicity, so that multiple reads
+ // do not interleave.
+ //
+ // rcvReadMu should be held before holding mu.
+ rcvReadMu sync.Mutex `state:"nosave"`
+
+ // rcvListMu synchronizes access to rcvList.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
+ rcvListMu sync.Mutex `state:"nosave"`
+
+ // rcvList is the queue for ready-for-delivery segments.
+ //
+ // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
+ // and removing segments from list. A range of segment can be determined, then
+ // temporarily release mu and rcvListMu while processing the segment range.
+ // This allows new segments to be appended to the list while processing.
+ //
+ // rcvListMu must be held to append segments to list.
rcvList segmentList `state:"wait"`
rcvClosed bool
// rcvBufSize is the total size of the receive buffer.
@@ -1309,8 +1322,69 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
e.UnlockUser()
}
-// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ e.rcvReadMu.Lock()
+ defer e.rcvReadMu.Unlock()
+
+ // N.B. Here we get a range of segments to be processed. It is safe to not
+ // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // can remove segments from the list through commitRead().
+ first, last, serr := e.startRead()
+ if serr != nil {
+ if serr == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return tcpip.ReadResult{}, serr
+ }
+
+ var err error
+ done := 0
+ s := first
+ for s != nil && done < count {
+ var n int
+ n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ // Book keeping first then error handling.
+
+ done += n
+
+ if opts.Peek {
+ // For peek, we use the (first, last) range of segment returned from
+ // startRead. We don't consume the receive buffer, so commitRead should
+ // not be called.
+ //
+ // N.B. It is important to use `last` to determine the last segment, since
+ // appending can happen while we process, and will lead to data race.
+ if s == last {
+ break
+ }
+ s = s.Next()
+ } else {
+ // N.B. commitRead() conveniently returns the next segment to read, after
+ // removing the data/segment that is read.
+ s = e.commitRead(n)
+ }
+
+ if err != nil {
+ break
+ }
+ }
+
+ // If something is read, we must report it. Report error when nothing is read.
+ if done == 0 && err != nil {
+ return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ }
+ return tcpip.ReadResult{
+ Count: done,
+ Total: done,
+ }, nil
+}
+
+// startRead checks that endpoint is in a readable state, and return the
+// inclusive range of segments that can be read.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1319,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1327,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
- e.rcvListMu.Unlock()
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return nil, nil, err
}
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
+ return nil, nil, tcpip.ErrNotConnected
}
- v, err := e.readLocked()
- e.rcvListMu.Unlock()
-
- if err == tcpip.ErrClosedForReceive {
- e.stats.ReadErrors.ReadClosed.Increment()
- }
- return v, tcpip.ControlMessages{}, err
-}
-
-func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return buffer.View{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
- return buffer.View{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
- s := e.rcvList.Front()
- views := s.data.Views()
- v := views[s.viewToDeliver]
- s.viewToDeliver++
+ return e.rcvList.Front(), e.rcvList.Back(), nil
+}
+
+// commitRead commits a read of done bytes and returns the next non-empty
+// segment to read. Data read from the segment must have also been removed from
+// the segment in order for this method to work correctly.
+//
+// It is performance critical to call commitRead frequently when servicing a big
+// Read request, so TCP can make progress timely. Right now, it is designed to
+// do this per segment read, hence this method conveniently returns the next
+// segment to read while holding the lock.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) commitRead(done int) *segment {
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
- var delta int
- if s.viewToDeliver >= len(views) {
+ memDelta := 0
+ s := e.rcvList.Front()
+ for s != nil && s.data.Size() == 0 {
e.rcvList.Remove(s)
- // We only free up receive buffer space when the segment is released as the
- // segment is still holding on to the views even though some views have been
- // read out to the user.
- delta = s.segMemSize()
+ // Memory is only considered released when the whole segment has been
+ // read.
+ memDelta += s.segMemSize()
s.decRef()
+ s = e.rcvList.Front()
}
+ e.rcvBufUsed -= done
- e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read freed up
- // enough buffer space, to either fit an aMSS or half a receive buffer
- // (whichever smaller), then notify the protocol goroutine to send a
- // window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ if memDelta > 0 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
}
- return v, nil
+ return e.rcvList.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
@@ -1499,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return queueAndSend()
}
-// Peek reads data without consuming it from the endpoint.
-//
-// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
- e.LockUser()
- defer e.UnlockUser()
-
- // The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.EndpointState(); !s.connected() && s != StateClose {
- if s == StateError {
- return 0, e.hardErrorLocked()
- }
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ErrInvalidEndpointState
- }
-
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
-
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
- e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ErrClosedForReceive
- }
- return 0, tcpip.ErrWouldBlock
- }
-
- // Make a copy of vec so we can modify the slide headers.
- vec = append([][]byte(nil), vec...)
-
- var num int64
- for s := e.rcvList.Front(); s != nil; s = s.Next() {
- views := s.data.Views()
-
- for i := s.viewToDeliver; i < len(views); i++ {
- v := views[i]
-
- for len(v) > 0 {
- if len(vec) == 0 {
- return num, nil
- }
- if len(vec[0]) == 0 {
- vec = vec[1:]
- continue
- }
-
- n := copy(vec[0], v)
- v = v[n:]
- vec[0] = vec[0][n:]
- num += int64(n)
- }
- }
- }
-
- return num, nil
-}
-
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
// Precondition: e.mu and e.rcvListMu must be held.
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 5ef73ec74..c5a6d2fba 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -37,7 +37,7 @@ const (
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
-// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+// segment is mostly immutable, the only field allowed to change is data.
//
// +stateify savable
type segment struct {
@@ -60,10 +60,7 @@ type segment struct {
hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
- // viewToDeliver keeps track of the next View that should be
- // delivered by the Read endpoint.
- viewToDeliver int
+ views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
@@ -84,6 +81,9 @@ type segment struct {
// acked indicates if the segment has already been SACKed.
acked bool
+
+ // dataMemSize is the memory used by data initially.
+ dataMemSize int
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
@@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
+ s.dataMemSize = s.data.Size()
return s
}
@@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
}
+ s.dataMemSize = s.data.Size()
return s
}
@@ -127,12 +129,12 @@ func (s *segment) clone() *segment {
netProto: s.netProto,
nicID: s.nicID,
remoteLinkAddr: s.remoteLinkAddr,
- viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
ep: s.ep,
qFlags: s.qFlags,
+ dataMemSize: s.dataMemSize,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -204,7 +206,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return SegSize + s.data.Size()
+ return SegSize + s.dataMemSize
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7dc2741a6..7422d8c02 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -24,16 +24,11 @@ import (
func (s *segment) saveData() buffer.VectorisedView {
// We cannot save s.data directly as s.data.views may alias to s.views,
// which is not allowed by state framework (in-struct pointer).
- v := make([]buffer.View, len(s.data.Views()))
- // For views already delivered, we cannot save them directly as they may
- // have already been sliced and saved elsewhere (e.g., readViews).
- for i := 0; i < s.viewToDeliver; i++ {
- v[i] = append([]byte(nil), s.data.Views()[i]...)
+ vs := make([]buffer.View, len(s.data.Views()))
+ for i, v := range s.data.Views() {
+ vs[i] = v
}
- for i := s.viewToDeliver; i < len(v); i++ {
- v[i] = s.data.Views()[i]
- }
- return buffer.NewVectorisedView(s.data.Size(), v)
+ return buffer.NewVectorisedView(s.data.Size(), vs)
}
// loadData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 5922083a9..aab92b94f 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -595,7 +595,6 @@ func (s *segment) StateFields() []string {
"remoteLinkAddr",
"data",
"hdr",
- "viewToDeliver",
"sequenceNumber",
"ackNumber",
"flags",
@@ -609,6 +608,7 @@ func (s *segment) StateFields() []string {
"xmitTime",
"xmitCount",
"acked",
+ "dataMemSize",
}
}
@@ -619,11 +619,11 @@ func (s *segment) StateSave(stateSinkObject state.Sink) {
var dataValue buffer.VectorisedView = s.saveData()
stateSinkObject.SaveValue(9, dataValue)
var optionsValue []byte = s.saveOptions()
- stateSinkObject.SaveValue(19, optionsValue)
+ stateSinkObject.SaveValue(18, optionsValue)
var rcvdTimeValue unixTime = s.saveRcvdTime()
- stateSinkObject.SaveValue(21, rcvdTimeValue)
+ stateSinkObject.SaveValue(20, rcvdTimeValue)
var xmitTimeValue unixTime = s.saveXmitTime()
- stateSinkObject.SaveValue(22, xmitTimeValue)
+ stateSinkObject.SaveValue(21, xmitTimeValue)
stateSinkObject.Save(0, &s.segmentEntry)
stateSinkObject.Save(1, &s.refCnt)
stateSinkObject.Save(2, &s.ep)
@@ -634,17 +634,17 @@ func (s *segment) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(7, &s.nicID)
stateSinkObject.Save(8, &s.remoteLinkAddr)
stateSinkObject.Save(10, &s.hdr)
- stateSinkObject.Save(11, &s.viewToDeliver)
- stateSinkObject.Save(12, &s.sequenceNumber)
- stateSinkObject.Save(13, &s.ackNumber)
- stateSinkObject.Save(14, &s.flags)
- stateSinkObject.Save(15, &s.window)
- stateSinkObject.Save(16, &s.csum)
- stateSinkObject.Save(17, &s.csumValid)
- stateSinkObject.Save(18, &s.parsedOptions)
- stateSinkObject.Save(20, &s.hasNewSACKInfo)
- stateSinkObject.Save(23, &s.xmitCount)
- stateSinkObject.Save(24, &s.acked)
+ stateSinkObject.Save(11, &s.sequenceNumber)
+ stateSinkObject.Save(12, &s.ackNumber)
+ stateSinkObject.Save(13, &s.flags)
+ stateSinkObject.Save(14, &s.window)
+ stateSinkObject.Save(15, &s.csum)
+ stateSinkObject.Save(16, &s.csumValid)
+ stateSinkObject.Save(17, &s.parsedOptions)
+ stateSinkObject.Save(19, &s.hasNewSACKInfo)
+ stateSinkObject.Save(22, &s.xmitCount)
+ stateSinkObject.Save(23, &s.acked)
+ stateSinkObject.Save(24, &s.dataMemSize)
}
func (s *segment) afterLoad() {}
@@ -660,21 +660,21 @@ func (s *segment) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(7, &s.nicID)
stateSourceObject.Load(8, &s.remoteLinkAddr)
stateSourceObject.Load(10, &s.hdr)
- stateSourceObject.Load(11, &s.viewToDeliver)
- stateSourceObject.Load(12, &s.sequenceNumber)
- stateSourceObject.Load(13, &s.ackNumber)
- stateSourceObject.Load(14, &s.flags)
- stateSourceObject.Load(15, &s.window)
- stateSourceObject.Load(16, &s.csum)
- stateSourceObject.Load(17, &s.csumValid)
- stateSourceObject.Load(18, &s.parsedOptions)
- stateSourceObject.Load(20, &s.hasNewSACKInfo)
- stateSourceObject.Load(23, &s.xmitCount)
- stateSourceObject.Load(24, &s.acked)
+ stateSourceObject.Load(11, &s.sequenceNumber)
+ stateSourceObject.Load(12, &s.ackNumber)
+ stateSourceObject.Load(13, &s.flags)
+ stateSourceObject.Load(14, &s.window)
+ stateSourceObject.Load(15, &s.csum)
+ stateSourceObject.Load(16, &s.csumValid)
+ stateSourceObject.Load(17, &s.parsedOptions)
+ stateSourceObject.Load(19, &s.hasNewSACKInfo)
+ stateSourceObject.Load(22, &s.xmitCount)
+ stateSourceObject.Load(23, &s.acked)
+ stateSourceObject.Load(24, &s.dataMemSize)
stateSourceObject.LoadValue(9, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) })
- stateSourceObject.LoadValue(19, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
- stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) })
- stateSourceObject.LoadValue(22, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(18, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
+ stateSourceObject.LoadValue(20, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) })
}
func (q *segmentQueue) StateTypeName() string {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4e8bd8b04..075de1db0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,6 +16,7 @@ package udp
import (
"fmt"
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
@@ -282,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
@@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
}
- return p.data.ToView(), cm, nil
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()