summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go265
-rw-r--r--pkg/syserr/netstack.go2
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go57
-rw-r--r--pkg/tcpip/buffer/view.go37
-rw-r--r--pkg/tcpip/buffer/view_test.go68
-rw-r--r--pkg/tcpip/checker/checker.go12
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go20
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go34
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go7
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go3
-rw-r--r--pkg/tcpip/stack/transport_test.go9
-rw-r--r--pkg/tcpip/tcpip.go81
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go28
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go36
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go31
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go60
-rw-r--r--pkg/tcpip/tests/integration/route_test.go69
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go38
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go46
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go34
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go228
-rw-r--r--pkg/tcpip/transport/tcp/segment.go16
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go272
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go23
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go31
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc16
-rw-r--r--test/syscalls/linux/socket_generic.cc9
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc16
35 files changed, 884 insertions, 729 deletions
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index fae3b6783..b2206900b 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -25,7 +25,6 @@ go_library(
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/metric",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index fe11fca9c..dcf898c0a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -28,9 +28,9 @@ import (
"bytes"
"fmt"
"io"
+ "io/ioutil"
"math"
"reflect"
- "sync/atomic"
"syscall"
"time"
@@ -43,7 +43,6 @@ import (
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -308,16 +307,8 @@ type socketOpsCommon struct {
skType linux.SockType
protocol int
- // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
- // Must be accessed using atomic operations. It must only be written
- // with readMu held but can be read without holding readMu. The latter
- // is required to avoid deadlocks in epoll Readiness checks.
- readViewHasData uint32
-
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
- // readView contains the remaining payload from the last packet.
- readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
readCM socket.IPControlMessages
@@ -336,8 +327,8 @@ type socketOpsCommon struct {
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
- // sockOptInq corresponds to TCP_INQ. It is implemented at this level
- // because it takes into account data from readView.
+ // TODO(b/153685824): Move this to SocketOptions.
+ // sockOptInq corresponds to TCP_INQ.
sockOptInq bool
}
@@ -377,41 +368,23 @@ func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
-// fetchReadView updates the readView field of the socket if it's currently
-// empty. It assumes that the socket is locked.
-//
// Precondition: s.readMu must be held.
-func (s *socketOpsCommon) fetchReadView() *syserr.Error {
- if len(s.readView) > 0 {
- return nil
- }
- s.readView = nil
- s.sender = tcpip.FullAddress{}
- s.linkPacketInfo = tcpip.LinkPacketInfo{}
+func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
+ res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
+ Peek: peek,
+ NeedRemoteAddr: true,
+ NeedLinkPacketInfo: true,
+ })
- var v buffer.View
- var cms tcpip.ControlMessages
- var err *tcpip.Error
+ // Assign these anyways.
+ s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
+ s.sender = res.RemoteAddr
+ s.linkPacketInfo = res.LinkPacketInfo
- switch e := s.Endpoint.(type) {
- // The ordering of these interfaces matters. The most specific
- // interfaces must be specified before the more generic Endpoint
- // interface.
- case tcpip.PacketEndpoint:
- v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
- case tcpip.Endpoint:
- v, cms, err = e.Read(&s.sender)
- }
if err != nil {
- atomic.StoreUint32(&s.readViewHasData, 0)
- return syserr.TranslateNetstackError(err)
+ return 0, 0, syserr.TranslateNetstackError(err)
}
-
- s.readView = v
- s.readCM = socket.NewIPControlMessages(s.family, cms)
- atomic.StoreUint32(&s.readViewHasData, 1)
-
- return nil
+ return res.Count, res.Total, nil
}
// Release implements fs.FileOperations.Release.
@@ -460,38 +433,14 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// WriteTo implements fs.FileOperations.WriteTo.
func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
+ defer s.readMu.Unlock()
- // Copy as much data as possible.
- done := int64(0)
- for count > 0 {
- // This may return a blocking error.
- if err := s.fetchReadView(); err != nil {
- s.readMu.Unlock()
- return done, err.ToError()
- }
-
- // Write to the underlying file.
- n, err := dst.Write(s.readView)
- done += int64(n)
- count -= int64(n)
- if dup {
- // That's all we support for dup. This is generally
- // supported by any Linux system calls, but the
- // expectation is that now a caller will call read to
- // actually remove these bytes from the socket.
- break
- }
-
- // Drop that part of the view.
- s.readView.TrimFront(n)
- if err != nil {
- s.readMu.Unlock()
- return done, err
- }
+ // This may return a blocking error.
+ n, _, err := s.readLocked(dst, int(count), dup /* peek */)
+ if err != nil {
+ return 0, err.ToError()
}
-
- s.readMu.Unlock()
- return done, nil
+ return int64(n), nil
}
// ioSequencePayload implements tcpip.Payload.
@@ -627,17 +576,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
// Readiness returns a mask of ready events for socket s.
func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
- r := s.Endpoint.Readiness(mask)
-
- // Check our cached value iff the caller asked for readability and the
- // endpoint itself is currently not readable.
- if (mask & ^r & waiter.EventIn) != 0 {
- if atomic.LoadUint32(&s.readViewHasData) == 1 {
- r |= waiter.EventIn
- }
- }
-
- return r
+ return s.Endpoint.Readiness(mask)
}
func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
@@ -2618,66 +2557,20 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return a, l, nil
}
-// coalescingRead is the fast path for non-blocking, non-peek, stream-based
-// case. It coalesces as many packets as possible before returning to the
-// caller.
+// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
//
// Precondition: s.readMu must be locked.
-func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
- var err *syserr.Error
- var copied int
-
- // Copy as many views as possible into the user-provided buffer.
- for {
- // Always do at least one fetchReadView, even if the number of bytes to
- // read is 0.
- err = s.fetchReadView()
- if err != nil || len(s.readView) == 0 {
- break
- }
- if dst.NumBytes() == 0 {
- break
- }
-
- var n int
- var e error
- if discard {
- n = len(s.readView)
- if int64(n) > dst.NumBytes() {
- n = int(dst.NumBytes())
- }
- } else {
- n, e = dst.CopyOut(ctx, s.readView)
- // Set the control message, even if 0 bytes were read.
- if e == nil {
- s.updateTimestamp()
- }
- }
- copied += n
- s.readView.TrimFront(n)
-
- dst = dst.DropFirst(n)
- if e != nil {
- err = syserr.FromError(e)
- break
- }
- // If we are done reading requested data then stop.
- if dst.NumBytes() == 0 {
- break
- }
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
+ // Always do at least one read, even if the number of bytes to read is 0.
+ var n int
+ n, _, err := s.readLocked(dst, count, false /* peek */)
+ if err != nil {
+ return 0, err
}
-
- // If we managed to copy something, we must deliver it.
- if copied > 0 {
- s.Endpoint.ModerateRecvBuf(copied)
- return copied, nil
+ if n > 0 {
+ s.Endpoint.ModerateRecvBuf(n)
}
-
- return 0, err
+ return n, nil
}
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
@@ -2689,7 +2582,7 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
return
}
cmsg.IP.HasInq = true
- cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+ cmsg.IP.Inq = int32(rcvBufUsed)
}
func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
@@ -2726,7 +2619,21 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
s.readMu.Lock()
- n, err := s.coalescingRead(ctx, dst, trunc)
+
+ var w io.Writer
+ if trunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
+
+ if err == nil && !trunc {
+ // Set the control message, even if 0 bytes were read.
+ s.updateTimestamp()
+ }
+
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
s.readMu.Unlock()
@@ -2736,18 +2643,32 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
s.readMu.Lock()
defer s.readMu.Unlock()
- if err := s.fetchReadView(); err != nil {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read, and does not write to buffer.
+ isTCPPeekTrunc := !isPacket && peek && trunc
+
+ var w io.Writer
+ if isTCPPeekTrunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ var numRead, numTotal int
+ var err *syserr.Error
+ numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, err
}
- if !isPacket && peek && trunc {
- // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
- // amount that could be read.
+ if isTCPPeekTrunc {
+ // TCP endpoint does not return the total bytes in buffer as numTotal.
+ // We need to query it from socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
- available := len(s.readView) + int(rql)
+ available := int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, 0, nil, 0, socket.ControlMessages{}, nil
@@ -2755,11 +2676,9 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
- n, err := dst.CopyOut(ctx, s.readView)
// Set the control message, even if 0 bytes were read.
- if err == nil {
- s.updateTimestamp()
- }
+ s.updateTimestamp()
+
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
@@ -2772,58 +2691,33 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
}
if peek {
- if l := len(s.readView); trunc && l > n {
+ if trunc && numTotal > numRead {
// isPacket must be true.
- return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
}
-
- if isPacket || err != nil {
- return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
- }
-
- // We need to peek beyond the first message.
- dst = dst.DropFirst(n)
- num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, err := s.Endpoint.Peek(dsts)
- // TODO(b/78348848): Handle peek timestamp.
- if err != nil {
- return int64(n), syserr.TranslateNetstackError(err).ToError()
- }
- return int64(n), nil
- }})
- n += int(num)
- if err == syserror.ErrWouldBlock && n > 0 {
- // We got some data, so no need to return an error.
- err = nil
- }
- return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ return numRead, 0, nil, 0, s.controlMessages(), nil
}
var msgLen int
if isPacket {
- msgLen = len(s.readView)
- s.readView = nil
+ msgLen = numTotal
} else {
- msgLen = int(n)
- s.readView.TrimFront(int(n))
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+ msgLen = numRead
}
var flags int
- if msgLen > int(n) {
+ if msgLen > numRead {
flags |= linux.MSG_TRUNC
}
+ n := numRead
if trunc {
n = msgLen
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
- return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
+ return n, flags, addr, addrLen, cmsg, nil
}
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
@@ -3090,11 +2984,6 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
return 0, syserr.TranslateNetstackError(terr).ToError()
}
- // Add bytes removed from the endpoint but not yet sent to the caller.
- s.readMu.Lock()
- v += len(s.readView)
- s.readMu.Unlock()
-
if v > math.MaxInt32 {
v = math.MaxInt32
}
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 77c3c110c..2756d4471 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -48,6 +48,7 @@ var (
ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM)
+ ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT)
)
var netstackErrorTranslations map[string]*Error
@@ -100,6 +101,7 @@ func init() {
addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled)
addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet)
addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported)
+ addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer)
}
// TranslateNetstackError converts an error from the tcpip package to a sentry
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 4f551cd92..7193f56ad 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -286,45 +286,47 @@ type opErrorer interface {
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
default:
}
- read, _, err := ep.Read(addr)
+ w := tcpip.SliceWriter(b)
+ opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
+ res, err := ep.Read(&w, len(b), opts)
if err == tcpip.ErrWouldBlock {
- if dontWait {
- return nil, errWouldBlock
- }
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- read, _, err = ep.Read(addr)
+ res, err = ep.Read(&w, len(b), opts)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
case <-notifyCh:
}
}
}
if err == tcpip.ErrClosedForReceive {
- return nil, io.EOF
+ return 0, io.EOF
}
if err != nil {
- return nil, errorer.newOpError("read", errors.New(err.String()))
+ return 0, errorer.newOpError("read", errors.New(err.String()))
}
- return read, nil
+ if addr != nil {
+ *addr = res.RemoteAddr
+ }
+ return res.Count, nil
}
// Read implements net.Conn.Read.
@@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) {
deadline := c.readCancel()
- numRead := 0
- defer func() {
- if numRead != 0 {
- c.ep.ModerateRecvBuf(numRead)
- }
- }()
- for numRead != len(b) {
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
- if err != nil {
- if numRead != 0 {
- return numRead, nil
- }
- return numRead, err
- }
- }
- n := copy(b[numRead:], c.read)
- c.read.TrimFront(n)
- numRead += n
- if len(c.read) == 0 {
- c.read = nil
- }
+ n, err := commonRead(b, c.ep, c.wq, deadline, nil, c)
+ if n != 0 {
+ c.ep.ModerateRecvBuf(n)
}
- return numRead, nil
+ return n, err
}
// Write implements net.Conn.Write.
@@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
+ n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c)
if err != nil {
return 0, nil, err
}
-
- return copy(b, read), fullToUDPAddr(addr), nil
+ return n, fullToUDPAddr(addr), nil
}
func (c *UDPConn) Write(b []byte) (int, error) {
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8db70a700..5dd1b1b6b 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) {
}
// Read implements io.Reader.
-func (vv *VectorisedView) Read(v View) (copied int, err error) {
- count := len(v)
+func (vv *VectorisedView) Read(b []byte) (copied int, err error) {
+ count := len(b)
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
vv.size -= count
- copy(v[copied:], vv.views[0][:count])
+ copy(b[copied:], vv.views[0][:count])
vv.views[0].TrimFront(count)
copied += count
return copied, nil
}
count -= len(vv.views[0])
- copy(v[copied:], vv.views[0])
+ copy(b[copied:], vv.views[0])
copied += len(vv.views[0])
vv.removeFirst()
}
@@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
return copied
}
+// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
+// unless peek is true.
+func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+ var err error
+ done := 0
+ for _, v := range vv.Views() {
+ remaining := count - done
+ if remaining <= 0 {
+ break
+ }
+ if len(v) > remaining {
+ v = v[:remaining]
+ }
+
+ var n int
+ n, err = dst.Write(v)
+ if n > 0 {
+ done += n
+ }
+ if err != nil {
+ break
+ }
+ }
+ if !peek {
+ vv.TrimFront(done)
+ }
+ return done, err
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index 726e54de9..e0ef8a94d 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -235,14 +235,16 @@ func TestToClone(t *testing.T) {
}
}
-func TestVVReadToVV(t *testing.T) {
- testCases := []struct {
- comment string
- vv VectorisedView
- bytesToRead int
- wantBytes string
- leftVV VectorisedView
- }{
+type readToTestCases struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+}
+
+func createReadToTestCases() []readToTestCases {
+ return []readToTestCases{
{
comment: "large VV, short read",
vv: vv(30, "012345678901234567890123456789"),
@@ -279,8 +281,10 @@ func TestVVReadToVV(t *testing.T) {
leftVV: vv(0, ""),
},
}
+}
- for _, tc := range testCases {
+func TestVVReadToVV(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
t.Run(tc.comment, func(t *testing.T) {
var readTo VectorisedView
inSize := tc.vv.Size()
@@ -301,6 +305,52 @@ func TestVVReadToVV(t *testing.T) {
}
}
+func TestVVReadTo(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
+ t.Run(tc.comment, func(t *testing.T) {
+ var dst bytes.Buffer
+ origSize := tc.vv.Size()
+ copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */)
+ if got, want := copied, len(tc.wantBytes); err != nil || got != want {
+ t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ }
+ if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ t.Errorf("got dst = %q, want %q", got, want)
+ }
+ if got, want := tc.vv.Size(), origSize-copied; got != want {
+ t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
+ }
+ })
+ }
+}
+
+func TestVVReadToPeek(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
+ t.Run(tc.comment, func(t *testing.T) {
+ var dst bytes.Buffer
+ origSize := tc.vv.Size()
+ origData := string(tc.vv.ToView())
+ copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */)
+ if got, want := copied, len(tc.wantBytes); err != nil || got != want {
+ t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ }
+ if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ t.Errorf("got dst = %q, want %q", got, want)
+ }
+ // Expect tc.vv is unchanged.
+ if got, want := tc.vv.Size(), origSize; got != want {
+ t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
+ }
+ if got, want := string(tc.vv.ToView()), origData; got != want {
+ t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func TestVVRead(t *testing.T) {
testCases := []struct {
comment string
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 91971b687..0ac2000ca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1603,3 +1603,15 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
}
}
}
+
+// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
+func IgnoreCmpPath(paths ...string) cmp.Option {
+ ignores := map[string]struct{}{}
+ for _, path := range paths {
+ ignores[path] = struct{}{}
+ }
+ return cmp.FilterPath(func(path cmp.Path) bool {
+ _, ok := ignores[path.String()]
+ return ok
+ }, cmp.Ignore())
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index ef62fe6fc..1c4919b1e 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,9 +15,11 @@
package ipv4_test
import (
+ "bytes"
"context"
"encoding/hex"
"fmt"
+ "io/ioutil"
"math"
"net"
"testing"
@@ -2408,18 +2410,26 @@ func TestReceiveFragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
+ const rcvSize = 65536 // Account for reassembled packets.
for i, expectedPayload := range test.expectedPayloads {
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ t.Fatalf("(i=%d) Read: %s", i, err)
}
- if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(expectedPayload),
+ Total: len(expectedPayload),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
+ }
+ if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
}
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 5f07d3af8..360025b20 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,8 +15,10 @@
package ipv6
import (
+ "bytes"
"encoding/hex"
"fmt"
+ "io/ioutil"
"math"
"net"
"testing"
@@ -844,13 +846,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
},
}
+ const mtu = header.IPv6MinimumMTU
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
+ e := channel.New(1, mtu, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -979,17 +982,24 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if got := stats.Value(); got != 1 {
t.Errorf("got UDP Rx Packets = %d, want = 1", got)
}
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("Read(nil): %s", err)
+ t.Fatalf("Read: %s", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(udpPayload),
+ Total: len(udpPayload),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
}
- if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
// Should not have any more UDP packets.
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -1969,18 +1979,20 @@ func TestReceiveIPv6Fragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
+ const rcvSize = 65536 // Account for reassembled packets.
for i, p := range test.expectedPayloads {
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ t.Fatalf("(i=%d) Read: %s", i, err)
}
- if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ if diff := cmp.Diff(p, buf.Bytes()); diff != "" {
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
}
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 51d428049..4777163cd 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -44,6 +44,7 @@ import (
"bufio"
"fmt"
"log"
+ "math"
"math/rand"
"net"
"os"
@@ -200,7 +201,7 @@ func main() {
// connection from its side.
wq.EventRegister(&waitEntry, waiter.EventIn)
for {
- v, _, err := ep.Read(nil)
+ _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrClosedForReceive {
break
@@ -213,8 +214,6 @@ func main() {
log.Fatal("Read() failed:", err)
}
-
- os.Stdout.Write(v)
}
wq.EventUnregister(&waitEntry)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 1c2afd554..a80fa0474 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,8 +20,10 @@
package main
import (
+ "bytes"
"flag"
"log"
+ "math"
"math/rand"
"net"
"os"
@@ -54,7 +56,8 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
defer wq.EventUnregister(&waitEntry)
for {
- v, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
@@ -64,7 +67,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
return
}
- ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{})
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 737d8d912..859278f0b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "io/ioutil"
"math"
"math/rand"
"testing"
@@ -351,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
ep := <-pollChannel
- if _, _, err := ep.Read(nil); err != nil {
+ if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index dd552b8b9..a5facf578 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "io"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -85,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return buffer.View{}, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ return tcpip.ReadResult{}, nil
}
func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -110,10 +111,6 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// SetSockOpt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index ef0f51f1a..f798056c0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "io"
"math/bits"
"reflect"
"strconv"
@@ -39,7 +40,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -113,6 +113,7 @@ var (
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
ErrMalformedHeader = &Error{msg: "header is malformed"}
+ ErrBadBuffer = &Error{msg: "bad buffer"}
)
var messageToError map[string]*Error
@@ -162,6 +163,7 @@ func StringToError(s string) *Error {
ErrNotPermitted,
ErrAddressFamilyNotSupported,
ErrMalformedHeader,
+ ErrBadBuffer,
}
messageToError = make(map[string]*Error)
@@ -496,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) {
return s[:size], nil
}
+var _ io.Writer = (*SliceWriter)(nil)
+
+// SliceWriter implements io.Writer for slices.
+type SliceWriter []byte
+
+// Write implements io.Writer.Write.
+func (s *SliceWriter) Write(b []byte) (int, error) {
+ n := copy(*s, b)
+ *s = (*s)[n:]
+ if n < len(b) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
// A ControlMessages contains socket control messages for IP sockets.
//
// +stateify savable
@@ -552,6 +569,40 @@ type PacketOwner interface {
GID() uint32
}
+// ReadOptions contains options for Endpoint.Read.
+type ReadOptions struct {
+ // Peek indicates whether this read is a peek.
+ Peek bool
+
+ // NeedRemoteAddr indicates whether to return the remote address, if
+ // supported.
+ NeedRemoteAddr bool
+
+ // NeedLinkPacketInfo indicates whether to return the link-layer information,
+ // if supported.
+ NeedLinkPacketInfo bool
+}
+
+// ReadResult represents result for a successful Endpoint.Read.
+type ReadResult struct {
+ // Count is the number of bytes received and written to the buffer.
+ Count int
+
+ // Total is the number of bytes of the received packet. This can be used to
+ // determine whether the read is truncated.
+ Total int
+
+ // ControlMessages is the control messages received.
+ ControlMessages ControlMessages
+
+ // RemoteAddr is the remote address if ReadOptions.NeedAddr is true.
+ RemoteAddr FullAddress
+
+ // LinkPacketInfo is the link-layer information of the received packet if
+ // ReadOptions.NeedLinkPacketInfo is true.
+ LinkPacketInfo LinkPacketInfo
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -566,11 +617,15 @@ type Endpoint interface {
// Abort is best effort; implementing Abort with Close is acceptable.
Abort()
- // Read reads data from the endpoint and optionally returns the sender.
+ // Read reads data from the endpoint and optionally writes to dst.
+ //
+ // This method does not block if there is no data pending; in this case,
+ // ErrWouldBlock is returned.
//
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+ // If non-zero number of bytes are successfully read and written to dst, err
+ // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
+ // should be returned.
+ Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -592,11 +647,6 @@ type Endpoint interface {
// not). The channel is only non-nil in this case.
Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
- // Peek reads data without consuming it from the endpoint.
- //
- // This method does not block if there is no data pending.
- Peek([][]byte) (int64, *Error)
-
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
@@ -703,17 +753,6 @@ type LinkPacketInfo struct {
PktType PacketType
}
-// PacketEndpoint are additional methods that are only implemented by Packet
-// endpoints.
-type PacketEndpoint interface {
- // ReadPacket reads a datagram/packet from the endpoint and optionally
- // returns the sender and additional LinkPacketInfo.
- //
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
-}
-
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index bb3b2ed0d..ca1e88e99 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -15,6 +15,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 907565ac4..60054d6ef 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -15,12 +15,13 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
@@ -382,24 +383,33 @@ func TestForwarding(t *testing.T) {
// Wait for the endpoint to be readable.
<-ch
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, len(data), opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ RemoteAddr: tcpip.FullAddress{Addr: expectedFrom},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != expectedFrom {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
}
- return addr
+ return res.RemoteAddr
}
addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index b41b72381..209da3903 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -15,12 +15,13 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
@@ -86,21 +87,21 @@ func TestPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
- icmpBuf func(*testing.T) buffer.View
+ icmpBuf func(*testing.T) []byte
}{
{
name: "IPv4 Ping",
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) buffer.View {
+ icmpBuf: func(t *testing.T) []byte {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
hdr.SetType(header.ICMPv4Echo)
if n := copy(hdr.Payload(), data[:]); n != len(data) {
t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
}
- return buffer.View(hdr)
+ return hdr
},
},
{
@@ -108,14 +109,14 @@ func TestPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) buffer.View {
+ icmpBuf: func(t *testing.T) []byte {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
hdr.SetType(header.ICMPv6EchoRequest)
if n := copy(hdr.Payload(), data[:]); n != len(data) {
t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
}
- return buffer.View(hdr)
+ return hdr
},
},
}
@@ -200,16 +201,25 @@ func TestPing(t *testing.T) {
// Wait for the endpoint to be readable.
<-waiterCH
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, len(icmpBuf), opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
}
- if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != test.remoteAddr {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
})
}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index baaa741cd..cf9e86c3c 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -15,12 +15,14 @@
package integration_test
import (
+ "bytes"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -238,21 +240,28 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
}
- var addr tcpip.FullAddress
- if gotPayload, _, err := rep.Read(&addr); test.expectRx {
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ if res, err := rep.Read(&buf, len(data), opts); test.expectRx {
if err != nil {
- t.Fatalf("reep.Read(_): %s", err)
- }
- if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if addr.Addr != test.addAddress.AddressWithPrefix.Address {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{
+ Addr: test.addAddress.AddressWithPrefix.Address,
+ },
+ }, res,
+ checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"),
+ ); diff != "" {
+ t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff)
}
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ } else if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
}
})
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 20f8a7e6c..fae6c256a 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -15,12 +15,14 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -462,17 +464,23 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
}
test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
- if gotPayload, _, err := ep.Read(nil); test.expectRx {
+ var buf bytes.Buffer
+ var opts tcpip.ReadOptions
+ if res, err := ep.Read(&buf, len(data), opts); test.expectRx {
if err != nil {
- t.Fatalf("Read(nil): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ } else if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
}
})
}
@@ -589,9 +597,19 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
// Wait for the endpoint to become readable.
<-rep.ch
- if gotPayload, _, err := rep.ep.Read(nil); err != nil {
- t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
- } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ var buf bytes.Buffer
+ result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ if err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
+ continue
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
+ }
+ if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
}
}
@@ -719,10 +737,20 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
}
test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
- if gotPayload, _, err := ep.Read(nil); err != nil {
- t.Fatalf("ep.Read(nil): %s", err)
- } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read: %s", err)
+ } else {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
}
// We should not receive UDP packets to the group once we leave
@@ -731,8 +759,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
if err := ep.SetSockOpt(&removeOpt); err != nil {
t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got ep.Read(nil) = (%x, _, %s), want = (nil, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock)
}
})
}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 02fc47015..52cf89b54 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -15,11 +15,14 @@
package integration_test
import (
+ "bytes"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -203,16 +206,25 @@ func TestLocalPing(t *testing.T) {
// Wait for the endpoint to become readable.
<-ch
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, math.MaxUint16, opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err)
}
- if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: test.localAddr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr)
+ if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
test.checkLinkEndpoint(t, e)
@@ -338,14 +350,27 @@ func TestLocalUDP(t *testing.T) {
<-serverCH
var clientAddr tcpip.FullAddress
- if v, _, err := server.Read(&clientAddr); err != nil {
+ var readBuf bytes.Buffer
+ if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("server.Read(_): %s", err)
} else {
- if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" {
- t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
+ clientAddr = read.RemoteAddr
+
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: readBuf.Len(),
+ Total: readBuf.Len(),
+ RemoteAddr: tcpip.FullAddress{
+ Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
+ },
+ }, read, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
}
- if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address {
- t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address)
+ if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" {
+ t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
@@ -367,15 +392,23 @@ func TestLocalUDP(t *testing.T) {
// Wait for the client endpoint to become readable.
<-clientCH
- var gotServerAddr tcpip.FullAddress
- if v, _, err := client.Read(&gotServerAddr); err != nil {
+ readBuf.Reset()
+ if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("client.Read(_): %s", err)
} else {
- if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" {
- t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: readBuf.Len(),
+ Total: readBuf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
+ }, read, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
}
- if gotServerAddr.Addr != serverAddr.Addr {
- t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr)
+ if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" {
+ t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d1e4a7cb7..2eb4457df 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,8 @@
package icmp
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -151,9 +153,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -163,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = p.senderAddress
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -329,11 +344,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
return nil
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e5e247342..3ab060751 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -26,6 +26,7 @@ package packet
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -160,8 +161,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.PacketEndpoint.ReadPacket.
-func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -173,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn
err = tcpip.ErrClosedForReceive
}
ep.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ if !opts.Peek {
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+ }
ep.rcvMu.Unlock()
- if addr != nil {
- *addr = packet.senderAddr
+ res := tcpip.ReadResult{
+ Total: packet.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: packet.timestampNS,
+ },
}
-
- if info != nil {
- *info = packet.packetInfo
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = packet.senderAddr
+ }
+ if opts.NeedLinkPacketInfo {
+ res.LinkPacketInfo = packet.packetInfo
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return ep.ReadPacket(addr, nil)
+ n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -203,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, tcpip.ErrInvalidOptionValue
}
-// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
func (*endpoint) Disconnect() *tcpip.Error {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7befcfc9b..dd260535f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -190,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -202,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
pkt := e.rcvList.Front()
- e.rcvList.Remove(pkt)
- e.rcvBufSize -= pkt.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = pkt.senderAddr
+ res := tcpip.ReadResult{
+ Total: pkt.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: pkt.timestampNS,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = pkt.senderAddr
}
- return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+ n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -363,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
return int64(len(payloadBytes)), nil, nil
}
-// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() *tcpip.Error {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index cf232b508..7e81203ba 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -112,6 +112,7 @@ go_test(
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/test/testutil",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6e3c8860e..8f3981075 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"encoding/binary"
"fmt"
+ "io"
"math"
"runtime"
"strings"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
@@ -393,15 +393,28 @@ type endpoint struct {
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
- // The following fields are used to manage the receive queue. The
- // protocol goroutine adds ready-for-delivery segments to rcvList,
- // which are returned by Read() calls to users.
+ // rcvReadMu synchronizes calls to Read.
//
- // Once the peer has closed its send side, rcvClosed is set to true
- // to indicate to users that no more data is coming.
+ // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // must be held during each read to ensure atomicity, so that multiple reads
+ // do not interleave.
+ //
+ // rcvReadMu should be held before holding mu.
+ rcvReadMu sync.Mutex `state:"nosave"`
+
+ // rcvListMu synchronizes access to rcvList.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
+ rcvListMu sync.Mutex `state:"nosave"`
+
+ // rcvList is the queue for ready-for-delivery segments.
+ //
+ // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
+ // and removing segments from list. A range of segment can be determined, then
+ // temporarily release mu and rcvListMu while processing the segment range.
+ // This allows new segments to be appended to the list while processing.
+ //
+ // rcvListMu must be held to append segments to list.
rcvList segmentList `state:"wait"`
rcvClosed bool
// rcvBufSize is the total size of the receive buffer.
@@ -1309,8 +1322,69 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
e.UnlockUser()
}
-// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ e.rcvReadMu.Lock()
+ defer e.rcvReadMu.Unlock()
+
+ // N.B. Here we get a range of segments to be processed. It is safe to not
+ // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // can remove segments from the list through commitRead().
+ first, last, serr := e.startRead()
+ if serr != nil {
+ if serr == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return tcpip.ReadResult{}, serr
+ }
+
+ var err error
+ done := 0
+ s := first
+ for s != nil && done < count {
+ var n int
+ n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ // Book keeping first then error handling.
+
+ done += n
+
+ if opts.Peek {
+ // For peek, we use the (first, last) range of segment returned from
+ // startRead. We don't consume the receive buffer, so commitRead should
+ // not be called.
+ //
+ // N.B. It is important to use `last` to determine the last segment, since
+ // appending can happen while we process, and will lead to data race.
+ if s == last {
+ break
+ }
+ s = s.Next()
+ } else {
+ // N.B. commitRead() conveniently returns the next segment to read, after
+ // removing the data/segment that is read.
+ s = e.commitRead(n)
+ }
+
+ if err != nil {
+ break
+ }
+ }
+
+ // If something is read, we must report it. Report error when nothing is read.
+ if done == 0 && err != nil {
+ return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ }
+ return tcpip.ReadResult{
+ Count: done,
+ Total: done,
+ }, nil
+}
+
+// startRead checks that endpoint is in a readable state, and return the
+// inclusive range of segments that can be read.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1319,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1327,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
- e.rcvListMu.Unlock()
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return nil, nil, err
}
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
+ return nil, nil, tcpip.ErrNotConnected
}
- v, err := e.readLocked()
- e.rcvListMu.Unlock()
-
- if err == tcpip.ErrClosedForReceive {
- e.stats.ReadErrors.ReadClosed.Increment()
- }
- return v, tcpip.ControlMessages{}, err
-}
-
-func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return buffer.View{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
- return buffer.View{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
- s := e.rcvList.Front()
- views := s.data.Views()
- v := views[s.viewToDeliver]
- s.viewToDeliver++
+ return e.rcvList.Front(), e.rcvList.Back(), nil
+}
+
+// commitRead commits a read of done bytes and returns the next non-empty
+// segment to read. Data read from the segment must have also been removed from
+// the segment in order for this method to work correctly.
+//
+// It is performance critical to call commitRead frequently when servicing a big
+// Read request, so TCP can make progress timely. Right now, it is designed to
+// do this per segment read, hence this method conveniently returns the next
+// segment to read while holding the lock.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) commitRead(done int) *segment {
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
- var delta int
- if s.viewToDeliver >= len(views) {
+ memDelta := 0
+ s := e.rcvList.Front()
+ for s != nil && s.data.Size() == 0 {
e.rcvList.Remove(s)
- // We only free up receive buffer space when the segment is released as the
- // segment is still holding on to the views even though some views have been
- // read out to the user.
- delta = s.segMemSize()
+ // Memory is only considered released when the whole segment has been
+ // read.
+ memDelta += s.segMemSize()
s.decRef()
+ s = e.rcvList.Front()
}
+ e.rcvBufUsed -= done
- e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read freed up
- // enough buffer space, to either fit an aMSS or half a receive buffer
- // (whichever smaller), then notify the protocol goroutine to send a
- // window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ if memDelta > 0 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
}
- return v, nil
+ return e.rcvList.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
@@ -1499,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return queueAndSend()
}
-// Peek reads data without consuming it from the endpoint.
-//
-// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
- e.LockUser()
- defer e.UnlockUser()
-
- // The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.EndpointState(); !s.connected() && s != StateClose {
- if s == StateError {
- return 0, e.hardErrorLocked()
- }
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ErrInvalidEndpointState
- }
-
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
-
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
- e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ErrClosedForReceive
- }
- return 0, tcpip.ErrWouldBlock
- }
-
- // Make a copy of vec so we can modify the slide headers.
- vec = append([][]byte(nil), vec...)
-
- var num int64
- for s := e.rcvList.Front(); s != nil; s = s.Next() {
- views := s.data.Views()
-
- for i := s.viewToDeliver; i < len(views); i++ {
- v := views[i]
-
- for len(v) > 0 {
- if len(vec) == 0 {
- return num, nil
- }
- if len(vec[0]) == 0 {
- vec = vec[1:]
- continue
- }
-
- n := copy(vec[0], v)
- v = v[n:]
- vec[0] = vec[0][n:]
- num += int64(n)
- }
- }
- }
-
- return num, nil
-}
-
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
// Precondition: e.mu and e.rcvListMu must be held.
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 5ef73ec74..c5a6d2fba 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -37,7 +37,7 @@ const (
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
-// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+// segment is mostly immutable, the only field allowed to change is data.
//
// +stateify savable
type segment struct {
@@ -60,10 +60,7 @@ type segment struct {
hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
- // viewToDeliver keeps track of the next View that should be
- // delivered by the Read endpoint.
- viewToDeliver int
+ views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
@@ -84,6 +81,9 @@ type segment struct {
// acked indicates if the segment has already been SACKed.
acked bool
+
+ // dataMemSize is the memory used by data initially.
+ dataMemSize int
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
@@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
+ s.dataMemSize = s.data.Size()
return s
}
@@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
}
+ s.dataMemSize = s.data.Size()
return s
}
@@ -127,12 +129,12 @@ func (s *segment) clone() *segment {
netProto: s.netProto,
nicID: s.nicID,
remoteLinkAddr: s.remoteLinkAddr,
- viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
ep: s.ep,
qFlags: s.qFlags,
+ dataMemSize: s.dataMemSize,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -204,7 +206,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return SegSize + s.data.Size()
+ return SegSize + s.dataMemSize
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7dc2741a6..7422d8c02 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -24,16 +24,11 @@ import (
func (s *segment) saveData() buffer.VectorisedView {
// We cannot save s.data directly as s.data.views may alias to s.views,
// which is not allowed by state framework (in-struct pointer).
- v := make([]buffer.View, len(s.data.Views()))
- // For views already delivered, we cannot save them directly as they may
- // have already been sliced and saved elsewhere (e.g., readViews).
- for i := 0; i < s.viewToDeliver; i++ {
- v[i] = append([]byte(nil), s.data.Views()[i]...)
+ vs := make([]buffer.View, len(s.data.Views()))
+ for i, v := range s.data.Views() {
+ vs[i] = v
}
- for i := s.viewToDeliver; i < len(v); i++ {
- v[i] = s.data.Views()[i]
- }
- return buffer.NewVectorisedView(s.data.Size(), v)
+ return buffer.NewVectorisedView(s.data.Size(), vs)
}
// loadData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cf60d5b53..9fa4672d7 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -17,10 +17,12 @@ package tcp_test
import (
"bytes"
"fmt"
+ "io/ioutil"
"math"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -40,6 +42,64 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// endpointTester provides helper functions to test a tcpip.Endpoint.
+type endpointTester struct {
+ ep tcpip.Endpoint
+}
+
+// CheckReadError issues a read to the endpoint and checking for an error.
+func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
+ t.Helper()
+ res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{})
+ if got != want {
+ t.Fatalf("ep.Read = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
+ t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
+ }
+}
+
+// CheckRead issues a read to the endpoint and checking for a success, returning
+// the data read.
+func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ return buf.Bytes()
+}
+
+// CheckReadFull reads from the endpoint for exactly count bytes.
+func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ var done int
+ for done < count {
+ res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{})
+ if err == tcpip.ErrWouldBlock {
+ // Wait for receive to be notified.
+ select {
+ case <-notifyRead:
+ case <-time.After(timeout):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+ continue
+ } else if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ done += res.Count
+ }
+ return buf.Bytes()
+}
+
const (
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -740,9 +800,7 @@ func TestSimpleReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -762,11 +820,7 @@ func TestSimpleReceive(t *testing.T) {
}
// Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1492,14 +1546,11 @@ func TestSynSent(t *testing.T) {
t.Fatal("timed out waiting for packet to arrive")
}
+ ept := endpointTester{c.EP}
if test.reset {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ ept.CheckReadError(t, tcpip.ErrConnectionRefused)
} else {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
- }
+ ept.CheckReadError(t, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
@@ -1524,9 +1575,8 @@ func TestOutOfOrderReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1551,9 +1601,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send the first 3 bytes now.
c.SendPacket(data[:3], &context.Headers{
@@ -1566,24 +1614,7 @@ func TestOutOfOrderReceive(t *testing.T) {
})
// Receive data.
- read := make([]byte, 0, 6)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- }
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
+ read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
// Check that we received the data in proper order.
if !bytes.Equal(data, read) {
@@ -1608,9 +1639,8 @@ func TestOutOfOrderFlood(t *testing.T) {
rcvBufSz := math.MaxUint16
c.CreateConnected(789, 30000, rcvBufSz)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1685,9 +1715,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1754,9 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1837,17 +1865,14 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
@@ -1865,10 +1890,8 @@ func TestFullWindowReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err := c.EP.Read(nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
@@ -1905,11 +1928,7 @@ func TestFullWindowReceive(t *testing.T) {
)
// Receive data and check it.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1991,8 +2010,9 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Read the data so that the subsequent ACK from the endpoint
// grows the right edge of the window.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("got Read(nil) = %s", err)
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ t.Fatalf("c.EP.Read: %s", err)
}
// Check if we have received max uint16 as our advertised
@@ -2027,9 +2047,9 @@ func TestNoWindowShrinking(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
+
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
seqNum := iss.Add(1)
@@ -2051,11 +2071,7 @@ func TestNoWindowShrinking(t *testing.T) {
}
// Read the 1 byte payload we just sent.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- if got, want := payload, v; !bytes.Equal(got, want) {
+ if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) {
t.Fatalf("got data: %v, want: %v", got, want)
}
@@ -2128,24 +2144,8 @@ func TestNoWindowShrinking(t *testing.T) {
),
)
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
// Receive data and check it.
- read := make([]byte, 0, rcvBufSize)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
-
+ read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
if !bytes.Equal(data, read) {
t.Fatalf("got data = %v, want = %v", read, data)
}
@@ -2569,11 +2569,11 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// we need to read at 3 packets.
sz := 0
for sz < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- sz += len(v)
+ sz += res.Count
}
checker.IPv4(t, c.GetPacket(),
@@ -3268,13 +3268,13 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, _, err := c.EP.Read(nil); err {
+ switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
}
break loop
case <-time.After(1 * time.Second):
@@ -4164,9 +4164,8 @@ func TestReadAfterClosedState(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -4224,35 +4223,31 @@ func TestReadAfterClosedState(t *testing.T) {
}
// Check that peek works.
- peekBuf := make([]byte, 10)
- n, err := c.EP.Peek([][]byte{peekBuf})
+ var peekBuf bytes.Buffer
+ res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
- peekBuf = peekBuf[:n]
- if !bytes.Equal(data, peekBuf) {
- t.Fatalf("got data = %v, want = %v", peekBuf, data)
+ if got, want := res.Count, len(data); got != want {
+ t.Fatalf("res.Count = %d, want %d", got, want)
}
-
- // Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
+ if !bytes.Equal(data, peekBuf.Bytes()) {
+ t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
}
+ // Receive data.
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
// Now that we drained the queue, check that functions fail with the
// right error code.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
-
- if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
}
}
@@ -4619,17 +4614,8 @@ func TestSelfConnect(t *testing.T) {
// Read back what was written.
wq.EventUnregister(&waitEntry)
wq.EventRegister(&waitEntry, waiter.EventIn)
- rd, _, err := ep.Read(nil)
- if err != nil {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
- <-notifyCh
- rd, _, err = ep.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
+ ept := endpointTester{ep}
+ rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
if !bytes.Equal(data, rd) {
t.Fatalf("got data = %v, want = %v", rd, data)
@@ -5082,9 +5068,8 @@ func TestKeepalive(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
@@ -5173,9 +5158,7 @@ func TestKeepalive(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
@@ -6070,9 +6053,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
- }
+ ept := endpointTester{ep}
+ ept.CheckReadError(t, tcpip.ErrNotConnected)
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
@@ -6227,7 +6209,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
- _, _, err := c.EP.Read(nil)
+ _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -6351,11 +6333,11 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// to happen before we measure the new window.
totalCopied := 0
for {
- b, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
- totalCopied += len(b)
+ totalCopied += res.Count
}
// Invoke the moderation API. This is required for auto-tuning
@@ -7272,9 +7254,8 @@ func TestTCPUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
@@ -7317,9 +7298,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
@@ -7358,9 +7338,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
@@ -7417,11 +7395,11 @@ func TestIncreaseWindowOnRead(t *testing.T) {
// defaultMTU is a good enough estimate for the MSS used for this
// connection.
for read < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- read += len(v)
+ read += res.Count
}
// After reading > MSS worth of data, we surely crossed MSS. See the ack:
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 0f9ed06cd..9e02d467d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -105,11 +106,18 @@ func TestTimeStampEnabledConnect(t *testing.T) {
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
@@ -286,11 +294,18 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
}
// Issue a read and we should data.
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 7ebae63d8..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,5 +58,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4e8bd8b04..075de1db0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,6 +16,7 @@ package udp
import (
"fmt"
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
@@ -282,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
@@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
}
- return p.data.ToView(), cm, nil
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8429f34b4..455b8c2aa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -18,10 +18,12 @@ import (
"bytes"
"context"
"fmt"
+ "io/ioutil"
"math/rand"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -595,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- var addr tcpip.FullAddress
- v, cm, err := c.ep.Read(&addr)
+ var buf bytes.Buffer
+ res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, cm, err = c.ep.Read(&addr)
+ res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -621,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
}
if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
}
- // Check the peer address.
+ // Check the read result.
h := flow.header4Tuple(incoming)
- if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages", // ControlMessages will be checked later.
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
}
// Check the payload.
+ v := buf.Bytes()
if !bytes.Equal(payload, v) {
c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
for _, f := range checkers {
- f(c.t, cm)
+ f(c.t, res.ControlMessages)
}
c.checkEndpointReadStats(1, epstats, err)
@@ -828,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
- if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
+ if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr {
+ t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
}
})
}
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index 06419772f..f8a0a80f2 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -204,7 +204,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
});
}
- for (int i = 0; i < kConnectAttempts; i++) {
+ for (int32_t i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
@@ -212,22 +212,8 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
connector.addr_len),
SyscallSucceeds());
- // Do two separate sends to ensure two segments are received. This is
- // required for netstack where read is incorrectly assuming a whole
- // segment is read when endpoint.Read() is called which is technically
- // incorrect as the syscall that invoked endpoint.Read() may only
- // consume it partially. This results in a case where a close() of
- // such a socket does not trigger a RST in netstack due to the
- // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
-
- // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
- // generates a RST.
- if (IsRunningOnGvisor()) {
- EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
- SyscallSucceedsWithValue(sizeof(i)));
- }
}
// Join threads to be sure that all connections have been counted.
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index a28ee2233..de0b8bb11 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -43,6 +43,15 @@ TEST_P(AllSocketPairTest, BasicReadWrite) {
EXPECT_EQ(data, absl::string_view(buf, 3));
}
+TEST_P(AllSocketPairTest, BasicReadWriteBadBuffer) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ const std::string data = "abc";
+ ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3),
+ SyscallSucceedsWithValue(3));
+ ASSERT_THAT(ReadFd(sockets->second_fd(), nullptr, 3),
+ SyscallFailsWithErrno(EFAULT));
+}
+
TEST_P(AllSocketPairTest, BasicSendRecv) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[512];
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 51b77ad85..a11147085 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -1507,7 +1507,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
}
ScopedThread connecting_thread([&connector, &conn_addr]() {
- for (int i = 0; i < kConnectAttempts; i++) {
+ for (int32_t i = 0; i < kConnectAttempts; i++) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
@@ -1515,22 +1515,8 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
connector.addr_len),
SyscallSucceeds());
- // Do two separate sends to ensure two segments are received. This is
- // required for netstack where read is incorrectly assuming a whole
- // segment is read when endpoint.Read() is called which is technically
- // incorrect as the syscall that invoked endpoint.Read() may only
- // consume it partially. This results in a case where a close() of
- // such a socket does not trigger a RST in netstack due to the
- // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
-
- // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
- // generates a RST.
- if (IsRunningOnGvisor()) {
- EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
- SyscallSucceedsWithValue(sizeof(i)));
- }
}
});